@@ -49,7 +49,12 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
4949 t = dnw .sigma_to_t (sigma_in )
5050
5151 if shared .sd_model .is_sdxl :
52- eps = shared .sd_model .model (x_in * c_in , t , {"crossattn" : cond_in ["c_crossattn" ][0 ]} )
52+ num_classes_hack = shared .sd_model .model .diffusion_model .num_classes
53+ shared .sd_model .model .diffusion_model .num_classes = None
54+ try :
55+ eps = shared .sd_model .model (x_in * c_in , t , {"crossattn" : cond_in ["c_crossattn" ][0 ]} )
56+ finally :
57+ shared .sd_model .model .diffusion_model .num_classes = num_classes_hack
5358 else :
5459 eps = shared .sd_model .apply_model (x_in * c_in , t , cond = cond_in )
5560
@@ -78,13 +83,6 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
7883
7984# Based on changes suggested by briansemrau in https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/736
8085def find_noise_for_image_sigma_adjustment (p , cond , uncond , cfg_scale , steps ):
81- if shared .sd_model .is_sdxl :
82- cond_tensor = cond ['crossattn' ]
83- uncond_tensor = uncond ['crossattn' ]
84- cond_in = torch .cat ([uncond_tensor , cond_tensor ])
85- else :
86- cond_in = torch .cat ([uncond , cond ])
87-
8886 x = p .init_latent
8987
9088 s_in = x .new_ones ([x .shape [0 ]])
@@ -124,7 +122,12 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
124122
125123
126124 if shared .sd_model .is_sdxl :
127- eps = shared .sd_model .model (x_in * c_in , t , {"crossattn" : cond_in ["c_crossattn" ][0 ]} )
125+ num_classes_hack = shared .sd_model .model .diffusion_model .num_classes
126+ shared .sd_model .model .diffusion_model .num_classes = None
127+ try :
128+ eps = shared .sd_model .model (x_in * c_in , t , {"crossattn" : cond_in ["c_crossattn" ][0 ]} )
129+ finally :
130+ shared .sd_model .model .diffusion_model .num_classes = num_classes_hack
128131 else :
129132 eps = shared .sd_model .apply_model (x_in * c_in , t , cond = cond_in )
130133
@@ -211,9 +214,19 @@ def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subs
211214 and self .cache .sigma_adjustment == sigma_adjustment
212215 same_everything = same_params and self .cache .latent .shape == lat .shape and np .abs (self .cache .latent - lat ).sum () < 100
213216
217+ rand_noise = processing .create_random_tensors (p .init_latent .shape [1 :], seeds = seeds , subseeds = subseeds , subseed_strength = p .subseed_strength , seed_resize_from_h = p .seed_resize_from_h , seed_resize_from_w = p .seed_resize_from_w , p = p )
218+
214219 if same_everything :
215220 rec_noise = self .cache .noise
216221 else :
222+ # This prevents a crash, because I don't know how to access the underlying .diffusion_model yet when controlnet is enabled. WIP
223+ # modules.sd_unet -> we're good
224+ # scripts.hook -> we're cooked
225+ if "scripts.hook" in str (shared .sd_model .model .diffusion_model .forward .__module__ ):
226+ print ("turn off any controlnets, do 1 pass and then turn controlnet back on to cache noise" )
227+ p .steps = 1
228+ return sd_samplers .create_sampler (p .sampler_name , p .sd_model ).sample_img2img (p , p .init_latent , rand_noise , conditioning , unconditional_conditioning , image_conditioning = p .image_conditioning )
229+
217230 shared .state .job_count += 1
218231 cond = p .sd_model .get_learned_conditioning (p .batch_size * [original_prompt ])
219232 uncond = p .sd_model .get_learned_conditioning (p .batch_size * [original_negative_prompt ])
@@ -223,8 +236,6 @@ def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subs
223236 rec_noise = find_noise_for_image (p , cond , uncond , cfg , st )
224237 self .cache = Cached (rec_noise , cfg , st , lat , original_prompt , original_negative_prompt , sigma_adjustment )
225238
226- rand_noise = processing .create_random_tensors (p .init_latent .shape [1 :], seeds = seeds , subseeds = subseeds , subseed_strength = p .subseed_strength , seed_resize_from_h = p .seed_resize_from_h , seed_resize_from_w = p .seed_resize_from_w , p = p )
227-
228239 combined_noise = ((1 - randomness ) * rec_noise + randomness * rand_noise ) / ((randomness ** 2 + (1 - randomness )** 2 ) ** 0.5 )
229240
230241 sampler = sd_samplers .create_sampler (p .sampler_name , p .sd_model )
0 commit comments