Skip to content

Commit 5362f20

Browse files
committed
Add description of normalization step
I added a short description of the normalization step done to simplify the computation of the KL-divergence. I also fixed the variance scaling to use (1-lamb)**2 instead of (1-lamb). We had fixed this bug for our computations, so the results in the paper use the correct scaling, but unfortunately the fix did not make it into this repo.
1 parent 071bba9 commit 5362f20

2 files changed

Lines changed: 35 additions & 12 deletions

File tree

attribution_bottleneck/bottleneck/per_sample_bottleneck.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,24 +60,45 @@ def reset_alpha(self):
6060
self.alpha.fill_(self.initial_value)
6161
return self.alpha
6262

63-
def forward(self, x):
64-
""" Remove information from x by performing a sampling step, parametrized by the mask alpha """
63+
def forward(self, r):
64+
""" Remove information from r by performing a sampling step, parametrized by the mask alpha """
6565
# Smoothen and expand a on batch dimension
6666
lamb = self.sigmoid(self.alpha)
67-
lamb = lamb.expand(x.shape[0], x.shape[1], -1, -1)
67+
lamb = lamb.expand(r.shape[0], r.shape[1], -1, -1)
6868
lamb = self.smooth(lamb) if self.smooth is not None else lamb
6969

70-
# Normalize x
71-
x_norm = (x - self.mean) / self.std
70+
# We normalize r to simplify the computation of the KL-divergence
71+
#
72+
# The equation in the paper is:
73+
# Z = λ * R + (1 - λ) * ε
74+
# where ε ~ N(μ_r, σ_r**2)
75+
# and given R the distribution of Z ~ N(λ * R + (1 - λ) * μ_r, ((1 - λ) * σ_r)**2)
76+
#
77+
# In the code μ_r = self.mean and σ_r = self.std.
78+
#
79+
# To simplify the computation of the KL-divergence we normalize:
80+
# R_norm = (R - μ_r) / σ_r
81+
# ε ~ N(0, 1)
82+
# Z_norm ~ N(λ * R_norm, (1 - λ)**2)
83+
# Z = σ_r * Z_norm + μ_r
84+
#
85+
# We compute KL[ N(λ * R_norm, (1 - λ)**2) || N(0, 1) ].
86+
#
87+
# The KL-divergence is invariant under invertible affine transformations such as this normalization, see:
88+
# https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence#Properties
89+
90+
r_norm = (r - self.mean) / self.std
7291

7392
# Get sampling parameters
74-
mu, log_var = x_norm * lamb, torch.log(1-lamb)
93+
noise_var = (1-lamb)**2
94+
scaled_signal = r_norm * lamb
95+
noise_log_var = torch.log(noise_var)
7596

76-
# Sample new output values from p(z|x)
77-
z_norm = self._sample_z(mu, log_var)
78-
self.buffer_capacity = self._calc_capacity(mu, log_var)
97+
# Sample new output values from p(z|r)
98+
z_norm = self._sample_z(scaled_signal, noise_log_var)
99+
self.buffer_capacity = self._calc_capacity(scaled_signal, noise_log_var)
79100

80-
# Denormalize z to match magnitude of x
101+
# Denormalize z to match magnitude of r
81102
z = z_norm * self.std + self.mean
82103

83104
# Clamp output, if input was post-relu

attribution_bottleneck/bottleneck/readout_bottleneck.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,11 @@ def forward_augmented(self, x, readouts):
121121
# Smoothing step
122122
lamb = self.smooth[0](lamb) if self.smooth is not None else lamb
123123

124-
# Normalize x
124+
# Normalize x, see per_sample_bottleneck.py for an explanation
125125
x_norm = (x - self.mean_0) / self.std_0
126-
mu, log_var = x_norm * lamb, torch.log(1-lamb)
126+
mu = x_norm * lamb
127+
var = (1-lamb)**2
128+
log_var = torch.log(var)
127129

128130
# Sampling step
129131
# Sample new output values from p(z|x)

0 commit comments

Comments
 (0)