File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -96,12 +96,20 @@ def sample(
9696 # For details, refer to line 3 of Algorithm 2 on page 4 of the paper
9797 noise = torch .randn_like (x ) if i > 1 else torch .zeros_like (x )
9898
99+ # Fix latent diffusion explosion problem
100+ if self .latent :
101+ x = x .clamp (- 1 , 1 )
102+
99103 # In each epoch, use x to calculate t - 1 of x
100104 # For details, refer to line 4 of Algorithm 2 on page 4 of the paper
101105 x = 1 / torch .sqrt (alpha ) * (
102106 x - ((1 - alpha ) / (torch .sqrt (1 - alpha_hat ))) * predicted_noise ) + torch .sqrt (
103107 beta ) * noise
104- # Post process
105- x = self .post_process (x = x .clamp (- 1 , 1 ))
108+ # Post process, output of constraint x
109+ if self .latent :
110+ x = self .post_process (x = x )
111+ else :
112+ x = self .post_process (x = x .clamp (- 1 , 1 ))
113+
106114 model .train ()
107115 return x
You can’t perform that action at this time.
0 commit comments