Skip to content

Commit 6fb1496

Browse files
committed
add universal noise and optional denoiser noise inputs
1 parent 33ec16d commit 6fb1496

12 files changed

Lines changed: 726 additions & 846 deletions

File tree

docs/contributing/NEW_MODEL_INTEGRATION.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,6 +1209,25 @@ export const NewModelSchedulerSelect = () => {
12091209
- [ ] Frontend UI component
12101210
- [ ] State management
12111211

1212+
**External Noise:**
1213+
- [ ] Add optional `noise: LatentsField` input to the denoise invocation
1214+
- [ ] Validate external noise shape against the architecture's expected
1215+
latent shape
1216+
- [ ] Preserve existing behavior when `noise` is not connected
1217+
- [ ] Extend `Universal Noise` when the architecture's latent noise contract
1218+
can be represented there
1219+
- [ ] Add a dedicated architecture-compatible noise invocation only when
1220+
`Universal Noise` cannot support the architecture cleanly
1221+
1222+
If your model supports external noise, the denoise invocation should accept
1223+
it as an optional input rather than replacing the existing seed-driven path.
1224+
When possible, wire the architecture into `Universal Noise` instead of
1225+
creating a separate noise node. Only create a dedicated noise invocation if
1226+
the architecture has a noise tensor contract that `Universal Noise` cannot
1227+
express cleanly. When external noise is connected, validate rank, channel
1228+
count, and spatial shape before blending it with init latents or using it as
1229+
the initial latent state.
1230+
12121231
---
12131232

12141233
## Summary: Minimal Integration

invokeai/app/invocations/anima_denoise.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from invokeai.app.invocations.model import TransformerField
4040
from invokeai.app.invocations.primitives import LatentsOutput
41+
from invokeai.app.invocations.universal_noise import validate_noise_tensor_shape
4142
from invokeai.app.services.shared.invocation_context import InvocationContext
4243
from invokeai.backend.anima.anima_transformer_patch import patch_anima_for_regional_prompting
4344
from invokeai.backend.anima.conditioning_data import AnimaRegionalTextConditioning, AnimaTextConditioning
@@ -165,7 +166,7 @@ def merge_intermediate_latents_with_init_latents(
165166
title="Denoise - Anima",
166167
tags=["image", "anima"],
167168
category="image",
168-
version="1.2.0",
169+
version="1.3.0",
169170
classification=Classification.Prototype,
170171
)
171172
class AnimaDenoiseInvocation(BaseInvocation):
@@ -181,6 +182,9 @@ class AnimaDenoiseInvocation(BaseInvocation):
181182
latents: Optional[LatentsField] = InputField(
182183
default=None, description=FieldDescriptions.latents, input=Input.Connection
183184
)
185+
noise: Optional[LatentsField] = InputField(
186+
default=None, description=FieldDescriptions.noise, input=Input.Connection
187+
)
184188
# denoise_mask is used for inpainting. Only the masked region is modified.
185189
denoise_mask: Optional[DenoiseMaskField] = InputField(
186190
default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection
@@ -459,7 +463,7 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
459463
init_latents = init_latents.unsqueeze(2) # [B, C, H, W] -> [B, C, 1, H, W]
460464

461465
# Generate initial noise (3D latent: [B, C, T, H, W])
462-
noise = self._get_noise(self.height, self.width, inference_dtype, device, self.seed)
466+
noise = self._prepare_noise_tensor(context, inference_dtype, device)
463467

464468
# Prepare input latents
465469
if init_latents is not None:
@@ -696,6 +700,16 @@ def _run_transformer(ctx: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> tor
696700
# Remove temporal dimension for output: [B, C, 1, H, W] -> [B, C, H, W]
697701
return latents.squeeze(2)
698702

703+
def _prepare_noise_tensor(
704+
self, context: InvocationContext, inference_dtype: torch.dtype, device: torch.device
705+
) -> torch.Tensor:
706+
if self.noise is not None:
707+
noise = context.tensors.load(self.noise.latents_name).to(device=device, dtype=inference_dtype)
708+
validate_noise_tensor_shape(noise, "Anima", self.width, self.height)
709+
return noise
710+
711+
return self._get_noise(self.height, self.width, inference_dtype, device, self.seed)
712+
699713
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
700714
def step_callback(state: PipelineIntermediateState) -> None:
701715
context.util.sd_step_callback(state, BaseModelType.Anima)

invokeai/app/invocations/cogview4_denoise.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
)
2121
from invokeai.app.invocations.model import TransformerField
2222
from invokeai.app.invocations.primitives import LatentsOutput
23+
from invokeai.app.invocations.universal_noise import validate_noise_tensor_shape
2324
from invokeai.app.services.shared.invocation_context import InvocationContext
2425
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
2526
from invokeai.backend.model_manager.taxonomy import BaseModelType
@@ -34,7 +35,7 @@
3435
title="Denoise - CogView4",
3536
tags=["image", "cogview4"],
3637
category="image",
37-
version="1.0.0",
38+
version="1.1.0",
3839
classification=Classification.Prototype,
3940
)
4041
class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -44,6 +45,9 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
4445
latents: Optional[LatentsField] = InputField(
4546
default=None, description=FieldDescriptions.latents, input=Input.Connection
4647
)
48+
noise: Optional[LatentsField] = InputField(
49+
default=None, description=FieldDescriptions.noise, input=Input.Connection
50+
)
4751
# denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
4852
denoise_mask: Optional[DenoiseMaskField] = InputField(
4953
default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection
@@ -245,15 +249,7 @@ def _run_diffusion(
245249
# Generate initial latent noise.
246250
num_channels_latents = transformer_info.model.config.in_channels # type: ignore
247251
assert isinstance(num_channels_latents, int)
248-
noise = self._get_noise(
249-
batch_size=1,
250-
num_channels_latents=num_channels_latents,
251-
height=self.height,
252-
width=self.width,
253-
dtype=inference_dtype,
254-
device=device,
255-
seed=self.seed,
256-
)
252+
noise = self._prepare_noise_tensor(context, num_channels_latents, inference_dtype, device)
257253

258254
# Prepare input latent image.
259255
if init_latents is not None:
@@ -356,6 +352,24 @@ def _run_diffusion(
356352

357353
return latents
358354

355+
def _prepare_noise_tensor(
356+
self, context: InvocationContext, num_channels_latents: int, inference_dtype: torch.dtype, device: torch.device
357+
) -> torch.Tensor:
358+
if self.noise is not None:
359+
noise = context.tensors.load(self.noise.latents_name).to(device=device, dtype=inference_dtype)
360+
validate_noise_tensor_shape(noise, "CogView4", self.width, self.height, num_channels=num_channels_latents)
361+
return noise
362+
363+
return self._get_noise(
364+
batch_size=1,
365+
num_channels_latents=num_channels_latents,
366+
height=self.height,
367+
width=self.width,
368+
dtype=inference_dtype,
369+
device=device,
370+
seed=self.seed,
371+
)
372+
359373
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
360374
def step_callback(state: PipelineIntermediateState) -> None:
361375
context.util.sd_step_callback(state, BaseModelType.CogView4)

invokeai/app/invocations/flux2_denoise.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
)
2424
from invokeai.app.invocations.model import TransformerField, VAEField
2525
from invokeai.app.invocations.primitives import LatentsOutput
26+
from invokeai.app.invocations.universal_noise import validate_noise_tensor_shape
2627
from invokeai.app.services.shared.invocation_context import InvocationContext
2728
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
2829
from invokeai.backend.flux.schedulers import FLUX_SCHEDULER_LABELS, FLUX_SCHEDULER_MAP, FLUX_SCHEDULER_NAME_VALUES
@@ -54,7 +55,7 @@
5455
title="FLUX2 Denoise",
5556
tags=["image", "flux", "flux2", "klein", "denoise"],
5657
category="image",
57-
version="1.4.0",
58+
version="1.5.0",
5859
classification=Classification.Prototype,
5960
)
6061
class Flux2DenoiseInvocation(BaseInvocation):
@@ -69,6 +70,11 @@ class Flux2DenoiseInvocation(BaseInvocation):
6970
description=FieldDescriptions.latents,
7071
input=Input.Connection,
7172
)
73+
noise: Optional[LatentsField] = InputField(
74+
default=None,
75+
description=FieldDescriptions.noise,
76+
input=Input.Connection,
77+
)
7278
denoise_mask: Optional[DenoiseMaskField] = InputField(
7379
default=None,
7480
description=FieldDescriptions.denoise_mask,
@@ -240,14 +246,7 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
240246
init_latents = init_latents.to(device=device, dtype=inference_dtype)
241247

242248
# Prepare input noise (FLUX.2 uses 32 channels)
243-
noise = get_noise_flux2(
244-
num_samples=1,
245-
height=self.height,
246-
width=self.width,
247-
device=device,
248-
dtype=inference_dtype,
249-
seed=self.seed,
250-
)
249+
noise = self._prepare_noise_tensor(context, inference_dtype, device)
251250
b, _c, latent_h, latent_w = noise.shape
252251
packed_h = latent_h // 2
253252
packed_w = latent_w // 2
@@ -486,6 +485,23 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
486485
x = unpack_flux2(x.float(), self.height, self.width)
487486
return x
488487

488+
def _prepare_noise_tensor(
489+
self, context: InvocationContext, inference_dtype: torch.dtype, device: torch.device
490+
) -> torch.Tensor:
491+
if self.noise is not None:
492+
noise = context.tensors.load(self.noise.latents_name).to(device=device, dtype=inference_dtype)
493+
validate_noise_tensor_shape(noise, "FLUX.2", self.width, self.height)
494+
return noise
495+
496+
return get_noise_flux2(
497+
num_samples=1,
498+
height=self.height,
499+
width=self.width,
500+
device=device,
501+
dtype=inference_dtype,
502+
seed=self.seed,
503+
)
504+
489505
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> Optional[torch.Tensor]:
490506
"""Prepare the inpaint mask."""
491507
if self.denoise_mask is None:

invokeai/app/invocations/flux_denoise.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from invokeai.app.invocations.ip_adapter import IPAdapterField
2929
from invokeai.app.invocations.model import ControlLoRAField, LoRAField, TransformerField, VAEField
3030
from invokeai.app.invocations.primitives import LatentsOutput
31+
from invokeai.app.invocations.universal_noise import validate_noise_tensor_shape
3132
from invokeai.app.services.shared.invocation_context import InvocationContext
3233
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
3334
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
@@ -71,7 +72,7 @@
7172
title="FLUX Denoise",
7273
tags=["image", "flux"],
7374
category="image",
74-
version="4.5.1",
75+
version="4.6.0",
7576
)
7677
class FluxDenoiseInvocation(BaseInvocation):
7778
"""Run denoising process with a FLUX transformer model."""
@@ -82,6 +83,11 @@ class FluxDenoiseInvocation(BaseInvocation):
8283
description=FieldDescriptions.latents,
8384
input=Input.Connection,
8485
)
86+
noise: Optional[LatentsField] = InputField(
87+
default=None,
88+
description=FieldDescriptions.noise,
89+
input=Input.Connection,
90+
)
8591
# denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
8692
denoise_mask: Optional[DenoiseMaskField] = InputField(
8793
default=None,
@@ -211,21 +217,15 @@ def _run_diffusion(
211217
context: InvocationContext,
212218
):
213219
inference_dtype = torch.bfloat16
220+
device = TorchDevice.choose_torch_device()
214221

215222
# Load the input latents, if provided.
216223
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
217224
if init_latents is not None:
218-
init_latents = init_latents.to(device=TorchDevice.choose_torch_device(), dtype=inference_dtype)
225+
init_latents = init_latents.to(device=device, dtype=inference_dtype)
219226

220227
# Prepare input noise.
221-
noise = get_noise(
222-
num_samples=1,
223-
height=self.height,
224-
width=self.width,
225-
device=TorchDevice.choose_torch_device(),
226-
dtype=inference_dtype,
227-
seed=self.seed,
228-
)
228+
noise = self._prepare_noise_tensor(context, inference_dtype, device)
229229
b, _c, latent_h, latent_w = noise.shape
230230
packed_h = latent_h // 2
231231
packed_w = latent_w // 2
@@ -237,7 +237,7 @@ def _run_diffusion(
237237
packed_height=packed_h,
238238
packed_width=packed_w,
239239
dtype=inference_dtype,
240-
device=TorchDevice.choose_torch_device(),
240+
device=device,
241241
)
242242
neg_text_conditionings: list[FluxTextConditioning] | None = None
243243
if self.negative_text_conditioning is not None:
@@ -247,14 +247,14 @@ def _run_diffusion(
247247
packed_height=packed_h,
248248
packed_width=packed_w,
249249
dtype=inference_dtype,
250-
device=TorchDevice.choose_torch_device(),
250+
device=device,
251251
)
252252
redux_conditionings: list[FluxReduxConditioning] = self._load_redux_conditioning(
253253
context=context,
254254
redux_cond_field=self.redux_conditioning,
255255
packed_height=packed_h,
256256
packed_width=packed_w,
257-
device=TorchDevice.choose_torch_device(),
257+
device=device,
258258
dtype=inference_dtype,
259259
)
260260
pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(
@@ -331,9 +331,7 @@ def _run_diffusion(
331331
img_cond: torch.Tensor | None = None
332332
is_flux_fill = transformer_config.variant is FluxVariantType.DevFill
333333
if is_flux_fill:
334-
img_cond = self._prep_flux_fill_img_cond(
335-
context, device=TorchDevice.choose_torch_device(), dtype=inference_dtype
336-
)
334+
img_cond = self._prep_flux_fill_img_cond(context, device=device, dtype=inference_dtype)
337335
else:
338336
if self.fill_conditioning is not None:
339337
raise ValueError("fill_conditioning was provided, but the model is not a FLUX Fill model.")
@@ -391,7 +389,7 @@ def _run_diffusion(
391389
if isinstance(self.kontext_conditioning, list)
392390
else [self.kontext_conditioning],
393391
vae_field=self.controlnet_vae,
394-
device=TorchDevice.choose_torch_device(),
392+
device=device,
395393
dtype=inference_dtype,
396394
)
397395

@@ -508,6 +506,23 @@ def _run_diffusion(
508506
x = unpack(x.float(), self.height, self.width)
509507
return x
510508

509+
def _prepare_noise_tensor(
510+
self, context: InvocationContext, inference_dtype: torch.dtype, device: torch.device
511+
) -> torch.Tensor:
512+
if self.noise is not None:
513+
noise = context.tensors.load(self.noise.latents_name).to(device=device, dtype=inference_dtype)
514+
validate_noise_tensor_shape(noise, "FLUX", self.width, self.height)
515+
return noise
516+
517+
return get_noise(
518+
num_samples=1,
519+
height=self.height,
520+
width=self.width,
521+
device=device,
522+
dtype=inference_dtype,
523+
seed=self.seed,
524+
)
525+
511526
def _load_text_conditioning(
512527
self,
513528
context: InvocationContext,

invokeai/app/invocations/metadata_linked.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -717,7 +717,7 @@ def _loras_to_json(obj: Union[Any, list[Any]]):
717717
md.update({"denoising_start": self.denoising_start})
718718
md.update({"denoising_end": self.denoising_end})
719719
md.update({"model": self.transformer.transformer})
720-
md.update({"seed": self.seed})
720+
md.update({"seed": self.noise.seed if self.noise is not None and self.noise.seed is not None else self.seed})
721721
md.update({"cfg_scale": self.cfg_scale})
722722
md.update({"cfg_scale_start_step": self.cfg_scale_start_step})
723723
md.update({"cfg_scale_end_step": self.cfg_scale_end_step})
@@ -735,7 +735,7 @@ def _loras_to_json(obj: Union[Any, list[Any]]):
735735
title=f"{ZImageDenoiseInvocation.UIConfig.title} + Metadata",
736736
tags=["z-image", "latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
737737
category="latents",
738-
version="1.0.0",
738+
version="1.1.0",
739739
)
740740
class ZImageDenoiseMetaInvocation(ZImageDenoiseInvocation, WithMetadata):
741741
"""Run denoising process with a Z-Image transformer model + metadata."""
@@ -766,7 +766,7 @@ def _loras_to_json(obj: Union[Any, list[Any]]):
766766
md.update({"denoising_end": self.denoising_end})
767767
md.update({"scheduler": self.scheduler})
768768
md.update({"model": self.transformer.transformer})
769-
md.update({"seed": self.seed})
769+
md.update({"seed": self.noise.seed if self.noise is not None and self.noise.seed is not None else self.seed})
770770
if len(self.transformer.loras) > 0:
771771
md.update({"loras": _loras_to_json(self.transformer.loras)})
772772

0 commit comments

Comments
 (0)