@@ -74,3 +74,51 @@ class CausalEncoderOutput(NamedTuple):
7474
class CausalDecoderOutput(NamedTuple):
    """Named-tuple result type carrying a single ``sample`` tensor."""

    # The only field; accessible as ``out.sample`` or ``out[0]``.
    sample: torch.Tensor
77+
78+
class DecoderOutput:
    """Plain container for the result of a decode call.

    Attribute names follow ``diffusers.models.autoencoders.vae.DecoderOutput``.

    Args:
        sample: Tensor stored on the instance as ``sample``.
        commit_loss: Optional tensor stored as ``commit_loss``; defaults to
            ``None`` when not supplied.
    """

    def __init__(
        self,
        sample: torch.Tensor,
        commit_loss: Optional[torch.Tensor] = None,
    ):
        # Both values are stored by reference; nothing is copied or validated.
        self.commit_loss: Optional[torch.Tensor] = commit_loss
        self.sample: torch.Tensor = sample
84+
85+
class DiagonalGaussianDistribution:
    """Diagonal Gaussian described by a stacked ``[mean, logvar]`` tensor.

    Modelled on ``diffusers.models.autoencoders.vae.DiagonalGaussianDistribution``.

    Args:
        parameters: Tensor that is split into two equal chunks along dim 1;
            the first chunk is used as the mean and the second as the
            log-variance.
        deterministic: When ``True``, ``std`` and ``var`` are replaced by zero
            tensors, ``sample()`` returns the mean and ``kl()`` returns zero.
    """

    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Bound the log-variance to [-30, 20] before it is exponentiated.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # std and var intentionally share one zero tensor.
            self.var = self.std = torch.zeros_like(
                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
            )

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        """Draw ``mean + std * eps`` with ``eps`` standard normal.

        Args:
            generator: Optional RNG. A CPU generator may be used even when the
                parameters live on another device; the noise is then drawn on
                CPU and moved to the parameters' device.

        Returns:
            A tensor shaped like ``mean``; the mean itself in deterministic mode.
        """
        if self.deterministic:
            return self.mode()

        target_device = self.parameters.device
        draw_device = target_device
        # torch.randn requires the generator and the output to share a device,
        # so a CPU generator paired with non-CPU parameters used to raise.
        if (
            generator is not None
            and generator.device.type == "cpu"
            and target_device.type != "cpu"
        ):
            draw_device = generator.device

        noise = torch.randn(
            self.mean.shape,
            generator=generator,
            device=draw_device,
            dtype=self.parameters.dtype,
        )
        if draw_device != target_device:
            noise = noise.to(target_device)
        return self.mean + self.std * noise

    def mode(self) -> torch.Tensor:
        """Return the mean, which is the mode of a Gaussian."""
        return self.mean

    def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
        """Compute the KL divergence from this distribution to ``other``.

        Args:
            other: Distribution to compare against. When ``None`` the reference
                is a standard normal (zero mean, unit variance).

        Returns:
            The divergence summed over dims 1, 2 and 3. In deterministic mode a
            one-element zero tensor is returned instead.
        """
        if self.deterministic:
            # var is forced to zero while logvar keeps its clamped value, so
            # the closed form below would produce a meaningless number.
            return self.mean.new_zeros(1)

        # NOTE(review): the reduction is hard-coded to dims 1-3, i.e. it
        # presumes 4-D parameters; confirm before feeding 5-D (video) latents.
        reduce_dims = [1, 2, 3]
        if other is None:
            return 0.5 * torch.sum(
                self.mean.pow(2) + self.var - 1.0 - self.logvar,
                dim=reduce_dims,
            )
        return 0.5 * torch.sum(
            (self.mean - other.mean).pow(2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=reduce_dims,
        )
0 commit comments