Skip to content

Commit a350712

Browse files
feat: add configurable shift parameter for Z-Image (#9004)
* feat: add configurable shift parameter for Z-Image sigma schedule Add a shift (mu) override to the Z-Image denoise invocation and expose it in the UI. When left blank, shift is auto-calculated from image dimensions (existing behavior). Users can override to fine-tune the timestep schedule, with an inline X button to reset back to auto. * refactor: switch Z-Image sigma schedule from exponential to linear time shift Use shift directly as a linear multiplier instead of exp(mu), giving more predictable and uniform control over the timestep schedule. Auto-calculated values are converted via exp(mu) to preserve identical default behavior. * feat: recall Z-Image shift parameter from metadata Write z_image_shift into graph metadata and add a ZImageShift recall handler so the shift override can be restored from previously generated images. Auto-mode (null) is omitted from metadata to avoid persisting a stale value. --------- Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
1 parent 3c9b282 commit a350712

9 files changed

Lines changed: 150 additions & 15 deletions

File tree

invokeai/app/invocations/z_image_denoise.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
title="Denoise - Z-Image",
5151
tags=["image", "z-image"],
5252
category="image",
53-
version="1.4.0",
53+
version="1.5.0",
5454
classification=Classification.Prototype,
5555
)
5656
class ZImageDenoiseInvocation(BaseInvocation):
@@ -104,6 +104,15 @@ class ZImageDenoiseInvocation(BaseInvocation):
104104
description=FieldDescriptions.vae + " Required for control conditioning.",
105105
input=Input.Connection,
106106
)
107+
# Shift override for the sigma schedule. If None, shift is auto-calculated from image dimensions.
108+
shift: Optional[float] = InputField(
109+
default=None,
110+
ge=0.0,
111+
description="Override the timestep shift (mu) for the sigma schedule. "
112+
"Leave blank to auto-calculate based on image dimensions (recommended). "
113+
"Lower values (~0.5) produce less noise shifting, higher values (~1.15) produce more.",
114+
title="Shift",
115+
)
107116
# Scheduler selection for the denoising process
108117
scheduler: ZIMAGE_SCHEDULER_NAME_VALUES = InputField(
109118
default="euler",
@@ -225,34 +234,36 @@ def _calculate_shift(
225234
"""Calculate timestep shift based on image sequence length.
226235
227236
Based on diffusers ZImagePipeline.calculate_shift method.
237+
Returns a linear shift value (exp(mu) from the original formula).
228238
"""
239+
import math
240+
229241
m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
230242
b = base_shift - m * base_image_seq_len
231243
mu = image_seq_len * m + b
232-
return mu
244+
# Convert from exponential mu to linear shift value
245+
return math.exp(mu)
233246

234-
def _get_sigmas(self, mu: float, num_steps: int) -> list[float]:
235-
"""Generate sigma schedule with time shift.
247+
def _get_sigmas(self, shift: float, num_steps: int) -> list[float]:
248+
"""Generate sigma schedule with linear time shift.
236249
237-
Based on FlowMatchEulerDiscreteScheduler with shift.
250+
Uses linear time shift: shift / (shift + (1/t - 1)).
251+
The shift value is used directly as a multiplier.
238252
Generates num_steps + 1 sigma values (including terminal 0.0).
239253
"""
240-
import math
241254

242-
def time_shift(mu: float, sigma: float, t: float) -> float:
243-
"""Apply time shift to a single timestep value."""
255+
def time_shift(shift: float, t: float) -> float:
256+
"""Apply linear time shift to a single timestep value."""
244257
if t <= 0:
245258
return 0.0
246259
if t >= 1:
247260
return 1.0
248-
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
261+
return shift / (shift + (1 / t - 1))
249262

250-
# Generate linearly spaced values from 1 to 0 (excluding endpoints for safety)
251-
# then apply time shift
252263
sigmas = []
253264
for i in range(num_steps + 1):
254265
t = 1.0 - i / num_steps # Goes from 1.0 to 0.0
255-
sigma = time_shift(mu, 1.0, t)
266+
sigma = time_shift(shift, t)
256267
sigmas.append(sigma)
257268

258269
return sigmas
@@ -313,11 +324,14 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
313324
# Concatenate all negative embeddings
314325
neg_prompt_embeds = torch.cat([tc.prompt_embeds for tc in neg_text_conditionings], dim=0)
315326

316-
# Calculate shift based on image sequence length
317-
mu = self._calculate_shift(img_seq_len)
327+
# Calculate shift based on image sequence length, or use override
328+
if self.shift is not None:
329+
shift = self.shift
330+
else:
331+
shift = self._calculate_shift(img_seq_len)
318332

319333
# Generate sigma schedule with time shift
320-
sigmas = self._get_sigmas(mu, self.steps)
334+
sigmas = self._get_sigmas(shift, self.steps)
321335

322336
# Apply denoising_start and denoising_end clipping
323337
if self.denoising_start > 0 or self.denoising_end < 1:

invokeai/frontend/web/public/locales/en.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,6 +1001,7 @@
10011001
"seedVarianceEnabled": "Seed Variance Enabled",
10021002
"seedVarianceStrength": "Seed Variance Strength",
10031003
"seedVarianceRandomizePercent": "Seed Variance Randomize %",
1004+
"zImageShift": "Z-Image Shift",
10041005
"seed": "Seed",
10051006
"steps": "Steps",
10061007
"strength": "Image to image strength",

invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ const slice = createSlice({
8686
setZImageScheduler: (state, action: PayloadAction<'euler' | 'heun' | 'lcm'>) => {
8787
state.zImageScheduler = action.payload;
8888
},
89+
setZImageShift: (state, action: PayloadAction<number | null>) => {
90+
state.zImageShift = action.payload;
91+
},
8992
setZImageSeedVarianceEnabled: (state, action: PayloadAction<boolean>) => {
9093
state.zImageSeedVarianceEnabled = action.payload;
9194
},
@@ -535,6 +538,7 @@ export const {
535538
setFluxDypeScale,
536539
setFluxDypeExponent,
537540
setZImageScheduler,
541+
setZImageShift,
538542
setZImageSeedVarianceEnabled,
539543
setZImageSeedVarianceStrength,
540544
setZImageSeedVarianceRandomizePercent,
@@ -696,6 +700,7 @@ export const selectFluxDypePreset = createParamsSelector((params) => params.flux
696700
export const selectFluxDypeScale = createParamsSelector((params) => params.fluxDypeScale);
697701
export const selectFluxDypeExponent = createParamsSelector((params) => params.fluxDypeExponent);
698702
export const selectZImageScheduler = createParamsSelector((params) => params.zImageScheduler);
703+
export const selectZImageShift = createParamsSelector((params) => params.zImageShift);
699704
export const selectZImageSeedVarianceEnabled = createParamsSelector((params) => params.zImageSeedVarianceEnabled);
700705
export const selectZImageSeedVarianceStrength = createParamsSelector((params) => params.zImageSeedVarianceStrength);
701706
export const selectZImageSeedVarianceRandomizePercent = createParamsSelector(

invokeai/frontend/web/src/features/controlLayers/store/types.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,7 @@ export const zParamsState = z.object({
717717
fluxDypeScale: zParameterFluxDypeScale,
718718
fluxDypeExponent: zParameterFluxDypeExponent,
719719
zImageScheduler: zParameterZImageScheduler,
720+
zImageShift: z.number().min(0).max(3).nullable(),
720721
upscaleScheduler: zParameterScheduler,
721722
upscaleCfgScale: zParameterCFGScale,
722723
seed: zParameterSeed,
@@ -788,6 +789,7 @@ export const getInitialParamsState = (): ParamsState => ({
788789
fluxDypeScale: 2.0,
789790
fluxDypeExponent: 2.0,
790791
zImageScheduler: 'euler',
792+
zImageShift: null,
791793
upscaleScheduler: 'kdpm_2',
792794
upscaleCfgScale: 2,
793795
seed: 0,

invokeai/frontend/web/src/features/metadata/parsing.tsx

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ import {
4343
setZImageSeedVarianceEnabled,
4444
setZImageSeedVarianceRandomizePercent,
4545
setZImageSeedVarianceStrength,
46+
setZImageShift,
4647
vaeSelected,
4748
widthChanged,
4849
zImageQwen3EncoderModelSelected,
@@ -686,6 +687,24 @@ const ZImageSeedVarianceRandomizePercent: SingleMetadataHandler<number> = {
686687
};
687688
//#endregion ZImageSeedVarianceRandomizePercent
688689

690+
//#region ZImageShift
691+
const ZImageShift: SingleMetadataHandler<number> = {
692+
[SingleMetadataKey]: true,
693+
type: 'ZImageShift',
694+
parse: (metadata, _store) => {
695+
const raw = getProperty(metadata, 'z_image_shift');
696+
const parsed = z.number().min(0).max(3).parse(raw);
697+
return Promise.resolve(parsed);
698+
},
699+
recall: (value, store) => {
700+
store.dispatch(setZImageShift(value));
701+
},
702+
i18nKey: 'metadata.zImageShift',
703+
LabelComponent: MetadataLabel,
704+
ValueComponent: ({ value }: SingleMetadataValueProps<number>) => <MetadataPrimitiveValue value={value} />,
705+
};
706+
//#endregion ZImageShift
707+
689708
//#region RefinerModel
690709
const RefinerModel: SingleMetadataHandler<ParameterSDXLRefinerModel> = {
691710
[SingleMetadataKey]: true,
@@ -1314,6 +1333,7 @@ export const ImageMetadataHandlers = {
13141333
ZImageSeedVarianceEnabled,
13151334
ZImageSeedVarianceStrength,
13161335
ZImageSeedVarianceRandomizePercent,
1336+
ZImageShift,
13171337
LoRAs,
13181338
CanvasLayers,
13191339
RefImages,

invokeai/frontend/web/src/features/nodes/util/graph/generation/buildZImageGraph.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import {
88
selectZImageSeedVarianceEnabled,
99
selectZImageSeedVarianceRandomizePercent,
1010
selectZImageSeedVarianceStrength,
11+
selectZImageShift,
1112
selectZImageVaeModel,
1213
} from 'features/controlLayers/store/paramsSlice';
1314
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
@@ -58,6 +59,9 @@ export const buildZImageGraph = async (arg: GraphBuilderArg): Promise<GraphBuild
5859
// (1.0 means no CFG effect, matching FLUX convention)
5960
const { cfgScale: guidance_scale, steps, zImageScheduler } = params;
6061

62+
// Shift override (null = auto-calculate from image dimensions)
63+
const zImageShift = selectZImageShift(state);
64+
6165
// Seed Variance Enhancer settings
6266
const seedVarianceEnabled = selectZImageSeedVarianceEnabled(state);
6367
const seedVarianceStrength = selectZImageSeedVarianceStrength(state);
@@ -122,6 +126,7 @@ export const buildZImageGraph = async (arg: GraphBuilderArg): Promise<GraphBuild
122126
guidance_scale,
123127
steps,
124128
scheduler: zImageScheduler,
129+
shift: zImageShift ?? undefined,
125130
});
126131
const l2i = g.addNode({
127132
type: 'z_image_l2i',
@@ -216,6 +221,7 @@ export const buildZImageGraph = async (arg: GraphBuilderArg): Promise<GraphBuild
216221
z_image_seed_variance_enabled: seedVarianceEnabled,
217222
z_image_seed_variance_strength: seedVarianceStrength,
218223
z_image_seed_variance_randomize_percent: seedVarianceRandomizePercent,
224+
z_image_shift: zImageShift ?? undefined,
219225
});
220226
g.addEdgeToMetadata(seed, 'value', 'seed');
221227
g.addEdgeToMetadata(positivePrompt, 'value', 'positive_prompt');
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel, Text } from '@invoke-ai/ui-library';
2+
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
3+
import { selectZImageShift, setZImageShift } from 'features/controlLayers/store/paramsSlice';
4+
import type React from 'react';
5+
import { memo, useCallback } from 'react';
6+
import { PiXBold } from 'react-icons/pi';
7+
8+
const CONSTRAINTS = {
9+
initial: 3,
10+
sliderMin: 1,
11+
sliderMax: 7,
12+
numberInputMin: 0,
13+
numberInputMax: 10,
14+
fineStep: 0.1,
15+
coarseStep: 0.5,
16+
};
17+
18+
const MARKS = [1, 2, 3, 4, 5, 6, 7];
19+
20+
const ParamZImageShift = () => {
21+
const shift = useAppSelector(selectZImageShift);
22+
const dispatch = useAppDispatch();
23+
24+
const onChange = useCallback((v: number) => dispatch(setZImageShift(v)), [dispatch]);
25+
const onReset = useCallback(
26+
(e: React.MouseEvent) => {
27+
e.preventDefault();
28+
e.stopPropagation();
29+
dispatch(setZImageShift(null));
30+
},
31+
[dispatch]
32+
);
33+
34+
const displayValue = shift ?? CONSTRAINTS.initial;
35+
36+
return (
37+
<FormControl>
38+
<FormLabel>
39+
Shift{' '}
40+
{shift !== null ? (
41+
<Text as="span" cursor="pointer" onClick={onReset} display="inline-flex" verticalAlign="middle">
42+
<PiXBold />
43+
</Text>
44+
) : (
45+
<Text as="span" opacity={0.5} fontWeight="normal" fontSize="xs">
46+
(auto)
47+
</Text>
48+
)}
49+
</FormLabel>
50+
<CompositeSlider
51+
value={displayValue}
52+
defaultValue={CONSTRAINTS.initial}
53+
min={CONSTRAINTS.sliderMin}
54+
max={CONSTRAINTS.sliderMax}
55+
step={CONSTRAINTS.coarseStep}
56+
fineStep={CONSTRAINTS.fineStep}
57+
onChange={onChange}
58+
marks={MARKS}
59+
/>
60+
<CompositeNumberInput
61+
value={displayValue}
62+
defaultValue={CONSTRAINTS.initial}
63+
min={CONSTRAINTS.numberInputMin}
64+
max={CONSTRAINTS.numberInputMax}
65+
step={CONSTRAINTS.coarseStep}
66+
fineStep={CONSTRAINTS.fineStep}
67+
onChange={onChange}
68+
/>
69+
</FormControl>
70+
);
71+
};
72+
73+
export default memo(ParamZImageShift);

invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import ParamGuidance from 'features/parameters/components/Core/ParamGuidance';
2525
import ParamScheduler from 'features/parameters/components/Core/ParamScheduler';
2626
import ParamSteps from 'features/parameters/components/Core/ParamSteps';
2727
import ParamZImageScheduler from 'features/parameters/components/Core/ParamZImageScheduler';
28+
import ParamZImageShift from 'features/parameters/components/Core/ParamZImageShift';
2829
import ParamZImageSeedVarianceSettings from 'features/parameters/components/SeedVariance/ParamZImageSeedVarianceSettings';
2930
import { MainModelPicker } from 'features/settingsAccordions/components/GenerationSettingsAccordion/MainModelPicker';
3031
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
@@ -92,6 +93,7 @@ export const GenerationSettingsAccordion = memo(() => {
9293
<ParamSteps />
9394
{(isFLUX || isFlux2) && modelConfig && !isFluxFillMainModelModelConfig(modelConfig) && <ParamGuidance />}
9495
{!isFLUX && !isFlux2 && <ParamCFGScale />}
96+
{isZImage && <ParamZImageShift />}
9597
{isFLUX && <ParamFluxDypePreset />}
9698
{isFLUX && fluxDypePreset === 'manual' && <ParamFluxDypeScale />}
9799
{isFLUX && fluxDypePreset === 'manual' && <ParamFluxDypeExponent />}

invokeai/frontend/web/src/services/api/schema.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29072,6 +29072,12 @@ export type components = {
2907229072
* @default null
2907329073
*/
2907429074
vae?: components["schemas"]["VAEField"] | null;
29075+
/**
29076+
* Shift
29077+
* @description Override the timestep shift (mu) for the sigma schedule. Leave blank to auto-calculate based on image dimensions (recommended). Lower values (~0.5) produce less noise shifting, higher values (~1.15) produce more.
29078+
* @default null
29079+
*/
29080+
shift?: number | null;
2907529081
/**
2907629082
* Scheduler
2907729083
* @description Scheduler (sampler) for the denoising process. Euler is the default and recommended. Heun is 2nd-order (better quality, 2x slower). LCM works with Turbo only (not Base).
@@ -29199,6 +29205,12 @@ export type components = {
2919929205
* @default null
2920029206
*/
2920129207
vae?: components["schemas"]["VAEField"] | null;
29208+
/**
29209+
* Shift
29210+
* @description Override the timestep shift (mu) for the sigma schedule. Leave blank to auto-calculate based on image dimensions (recommended). Lower values (~0.5) produce less noise shifting, higher values (~1.15) produce more.
29211+
* @default null
29212+
*/
29213+
shift?: number | null;
2920229214
/**
2920329215
* Scheduler
2920429216
* @description Scheduler (sampler) for the denoising process. Euler is the default and recommended. Heun is 2nd-order (better quality, 2x slower). LCM works with Turbo only (not Base).

0 commit comments

Comments
 (0)