2828from invokeai .app .invocations .ip_adapter import IPAdapterField
2929from invokeai .app .invocations .model import ControlLoRAField , LoRAField , TransformerField , VAEField
3030from invokeai .app .invocations .primitives import LatentsOutput
31+ from invokeai .app .invocations .universal_noise import validate_noise_tensor_shape
3132from invokeai .app .services .shared .invocation_context import InvocationContext
3233from invokeai .backend .flux .controlnet .instantx_controlnet_flux import InstantXControlNetFlux
3334from invokeai .backend .flux .controlnet .xlabs_controlnet_flux import XLabsControlNetFlux
7172 title = "FLUX Denoise" ,
7273 tags = ["image" , "flux" ],
7374 category = "image" ,
74- version = "4.5.1 " ,
75+ version = "4.6.0 " ,
7576)
7677class FluxDenoiseInvocation (BaseInvocation ):
7778 """Run denoising process with a FLUX transformer model."""
@@ -82,6 +83,11 @@ class FluxDenoiseInvocation(BaseInvocation):
8283 description = FieldDescriptions .latents ,
8384 input = Input .Connection ,
8485 )
86+ noise : Optional [LatentsField ] = InputField (
87+ default = None ,
88+ description = FieldDescriptions .noise ,
89+ input = Input .Connection ,
90+ )
8591 # denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
8692 denoise_mask : Optional [DenoiseMaskField ] = InputField (
8793 default = None ,
@@ -211,21 +217,15 @@ def _run_diffusion(
211217 context : InvocationContext ,
212218 ):
213219 inference_dtype = torch .bfloat16
220+ device = TorchDevice .choose_torch_device ()
214221
215222 # Load the input latents, if provided.
216223 init_latents = context .tensors .load (self .latents .latents_name ) if self .latents else None
217224 if init_latents is not None :
218- init_latents = init_latents .to (device = TorchDevice . choose_torch_device () , dtype = inference_dtype )
225+ init_latents = init_latents .to (device = device , dtype = inference_dtype )
219226
220227 # Prepare input noise.
221- noise = get_noise (
222- num_samples = 1 ,
223- height = self .height ,
224- width = self .width ,
225- device = TorchDevice .choose_torch_device (),
226- dtype = inference_dtype ,
227- seed = self .seed ,
228- )
228+ noise = self ._prepare_noise_tensor (context , inference_dtype , device )
229229 b , _c , latent_h , latent_w = noise .shape
230230 packed_h = latent_h // 2
231231 packed_w = latent_w // 2
@@ -237,7 +237,7 @@ def _run_diffusion(
237237 packed_height = packed_h ,
238238 packed_width = packed_w ,
239239 dtype = inference_dtype ,
240- device = TorchDevice . choose_torch_device () ,
240+ device = device ,
241241 )
242242 neg_text_conditionings : list [FluxTextConditioning ] | None = None
243243 if self .negative_text_conditioning is not None :
@@ -247,14 +247,14 @@ def _run_diffusion(
247247 packed_height = packed_h ,
248248 packed_width = packed_w ,
249249 dtype = inference_dtype ,
250- device = TorchDevice . choose_torch_device () ,
250+ device = device ,
251251 )
252252 redux_conditionings : list [FluxReduxConditioning ] = self ._load_redux_conditioning (
253253 context = context ,
254254 redux_cond_field = self .redux_conditioning ,
255255 packed_height = packed_h ,
256256 packed_width = packed_w ,
257- device = TorchDevice . choose_torch_device () ,
257+ device = device ,
258258 dtype = inference_dtype ,
259259 )
260260 pos_regional_prompting_extension = RegionalPromptingExtension .from_text_conditioning (
@@ -331,9 +331,7 @@ def _run_diffusion(
331331 img_cond : torch .Tensor | None = None
332332 is_flux_fill = transformer_config .variant is FluxVariantType .DevFill
333333 if is_flux_fill :
334- img_cond = self ._prep_flux_fill_img_cond (
335- context , device = TorchDevice .choose_torch_device (), dtype = inference_dtype
336- )
334+ img_cond = self ._prep_flux_fill_img_cond (context , device = device , dtype = inference_dtype )
337335 else :
338336 if self .fill_conditioning is not None :
339337 raise ValueError ("fill_conditioning was provided, but the model is not a FLUX Fill model." )
@@ -391,7 +389,7 @@ def _run_diffusion(
391389 if isinstance (self .kontext_conditioning , list )
392390 else [self .kontext_conditioning ],
393391 vae_field = self .controlnet_vae ,
394- device = TorchDevice . choose_torch_device () ,
392+ device = device ,
395393 dtype = inference_dtype ,
396394 )
397395
@@ -508,6 +506,23 @@ def _run_diffusion(
508506 x = unpack (x .float (), self .height , self .width )
509507 return x
510508
509+ def _prepare_noise_tensor (
510+ self , context : InvocationContext , inference_dtype : torch .dtype , device : torch .device
511+ ) -> torch .Tensor :
512+ if self .noise is not None :
513+ noise = context .tensors .load (self .noise .latents_name ).to (device = device , dtype = inference_dtype )
514+ validate_noise_tensor_shape (noise , "FLUX" , self .width , self .height )
515+ return noise
516+
517+ return get_noise (
518+ num_samples = 1 ,
519+ height = self .height ,
520+ width = self .width ,
521+ device = device ,
522+ dtype = inference_dtype ,
523+ seed = self .seed ,
524+ )
525+
511526 def _load_text_conditioning (
512527 self ,
513528 context : InvocationContext ,
0 commit comments