@@ -148,50 +148,51 @@ def __call__(
148148 # Get primary attribution scores
149149 output .token_attributions = None
150150 output .normalized_token_attributions = None
151- if calculate_attributions and attribution_method == 'grad_x_input' :
151+ if calculate_attributions :
152+ if attribution_method == 'grad_x_input' :
152153
153- if self .verbose :
154- print ("Calculating token attributions... " , end = '' )
154+ if self .verbose :
155+ print ("Calculating token attributions... " , end = '' )
155156
156- token_attributions = gradient_x_inputs_attribution (
157- pred_logits = output .image , input_embeds = text_embeddings ,
158- explanation_2d_bounding_box = explanation_2d_bounding_box
159- )
160- token_attributions = token_attributions .detach ().cpu ().numpy ()
161-
162- # remove special tokens
163- assert len (token_attributions ) == len (tokens )
164- output .token_attributions = []
165- output .normalized_token_attributions = []
166- for image_token_attributions , image_tokens in zip (token_attributions , tokens ):
167- assert len (image_token_attributions ) == len (image_tokens )
168-
169- # Add token attributions
170- output .token_attributions .append ([])
171- for attr , token in zip (image_token_attributions , image_tokens ):
172- if consider_special_tokens or token not in self .special_tokens_attributes :
173-
174- if clean_token_prefixes_and_suffixes :
175- token = clean_token_from_prefixes_and_suffixes (token )
176-
177- output .token_attributions [- 1 ].append (
178- (token , attr )
179- )
180-
181- # Add normalized
182- total = sum ([attr for _ , attr in output .token_attributions [- 1 ]])
183- output .normalized_token_attributions .append (
184- [
185- (token , round (100 * attr / total , 3 ))
186- for token , attr in output .token_attributions [- 1 ]
187- ]
157+ token_attributions = gradient_x_inputs_attribution (
158+ pred_logits = output .image , input_embeds = text_embeddings ,
159+ explanation_2d_bounding_box = explanation_2d_bounding_box
188160 )
161+ token_attributions = token_attributions .detach ().cpu ().numpy ()
162+
163+ # remove special tokens
164+ assert len (token_attributions ) == len (tokens )
165+ output .token_attributions = []
166+ output .normalized_token_attributions = []
167+ for image_token_attributions , image_tokens in zip (token_attributions , tokens ):
168+ assert len (image_token_attributions ) == len (image_tokens )
169+
170+ # Add token attributions
171+ output .token_attributions .append ([])
172+ for attr , token in zip (image_token_attributions , image_tokens ):
173+ if consider_special_tokens or token not in self .special_tokens_attributes :
174+
175+ if clean_token_prefixes_and_suffixes :
176+ token = clean_token_from_prefixes_and_suffixes (token )
177+
178+ output .token_attributions [- 1 ].append (
179+ (token , attr )
180+ )
181+
182+ # Add normalized
183+ total = sum ([attr for _ , attr in output .token_attributions [- 1 ]])
184+ output .normalized_token_attributions .append (
185+ [
186+ (token , round (100 * attr / total , 3 ))
187+ for token , attr in output .token_attributions [- 1 ]
188+ ]
189+ )
190+
191+ if self .verbose :
192+ print ("Done!" )
189193
190- if self .verbose :
191- print ("Done!" )
192-
193- else :
194- raise NotImplementedError ("Only `attribution_method='grad_x_input'` is implemented for now" )
194+ else :
195+ raise NotImplementedError ("Only `attribution_method='grad_x_input'` is implemented for now" )
195196
196197 if batch_size == 1 :
197198 # squash batch dimension
0 commit comments