Skip to content

Commit 5fb28d5

Browse files
committed
fix n_last_diffusion_steps_to_consider_for_attributions=0
1 parent 659ef21 commit 5fb28d5

2 files changed

Lines changed: 42 additions & 41 deletions

File tree

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
setup(
1111
name='diffusers-interpret',
12-
version='0.3.0',
12+
version='0.3.1',
1313
description='diffusers-interpret: model explainability for 🤗 Diffusers',
1414
long_description=long_description,
1515
long_description_content_type='text/markdown',

src/diffusers_interpret/explainer.py

Lines changed: 41 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -148,50 +148,51 @@ def __call__(
148148
# Get primary attribution scores
149149
output.token_attributions = None
150150
output.normalized_token_attributions = None
151-
if calculate_attributions and attribution_method == 'grad_x_input':
151+
if calculate_attributions:
152+
if attribution_method == 'grad_x_input':
152153

153-
if self.verbose:
154-
print("Calculating token attributions... ", end='')
154+
if self.verbose:
155+
print("Calculating token attributions... ", end='')
155156

156-
token_attributions = gradient_x_inputs_attribution(
157-
pred_logits=output.image, input_embeds=text_embeddings,
158-
explanation_2d_bounding_box=explanation_2d_bounding_box
159-
)
160-
token_attributions = token_attributions.detach().cpu().numpy()
161-
162-
# remove special tokens
163-
assert len(token_attributions) == len(tokens)
164-
output.token_attributions = []
165-
output.normalized_token_attributions = []
166-
for image_token_attributions, image_tokens in zip(token_attributions, tokens):
167-
assert len(image_token_attributions) == len(image_tokens)
168-
169-
# Add token attributions
170-
output.token_attributions.append([])
171-
for attr, token in zip(image_token_attributions, image_tokens):
172-
if consider_special_tokens or token not in self.special_tokens_attributes:
173-
174-
if clean_token_prefixes_and_suffixes:
175-
token = clean_token_from_prefixes_and_suffixes(token)
176-
177-
output.token_attributions[-1].append(
178-
(token, attr)
179-
)
180-
181-
# Add normalized
182-
total = sum([attr for _, attr in output.token_attributions[-1]])
183-
output.normalized_token_attributions.append(
184-
[
185-
(token, round(100 * attr / total, 3))
186-
for token, attr in output.token_attributions[-1]
187-
]
157+
token_attributions = gradient_x_inputs_attribution(
158+
pred_logits=output.image, input_embeds=text_embeddings,
159+
explanation_2d_bounding_box=explanation_2d_bounding_box
188160
)
161+
token_attributions = token_attributions.detach().cpu().numpy()
162+
163+
# remove special tokens
164+
assert len(token_attributions) == len(tokens)
165+
output.token_attributions = []
166+
output.normalized_token_attributions = []
167+
for image_token_attributions, image_tokens in zip(token_attributions, tokens):
168+
assert len(image_token_attributions) == len(image_tokens)
169+
170+
# Add token attributions
171+
output.token_attributions.append([])
172+
for attr, token in zip(image_token_attributions, image_tokens):
173+
if consider_special_tokens or token not in self.special_tokens_attributes:
174+
175+
if clean_token_prefixes_and_suffixes:
176+
token = clean_token_from_prefixes_and_suffixes(token)
177+
178+
output.token_attributions[-1].append(
179+
(token, attr)
180+
)
181+
182+
# Add normalized
183+
total = sum([attr for _, attr in output.token_attributions[-1]])
184+
output.normalized_token_attributions.append(
185+
[
186+
(token, round(100 * attr / total, 3))
187+
for token, attr in output.token_attributions[-1]
188+
]
189+
)
190+
191+
if self.verbose:
192+
print("Done!")
189193

190-
if self.verbose:
191-
print("Done!")
192-
193-
else:
194-
raise NotImplementedError("Only `attribution_method='grad_x_input'` is implemented for now")
194+
else:
195+
raise NotImplementedError("Only `attribution_method='grad_x_input'` is implemented for now")
195196

196197
if batch_size == 1:
197198
# squash batch dimension

0 commit comments

Comments
 (0)