1+ from typing import List , Optional
2+
13import numpy as np
24import torch
35
@@ -40,14 +42,16 @@ def map_attributions_to_char(attributions, offsets, text):
4042 if attributions .ndim == 1 :
4143 attributions = attributions [None , :]
4244
43- attributions_per_char = np .empty ((attributions .shape [0 ], len (text ))) # top_k, text_len
45+ attributions_per_char = np .zeros ((attributions .shape [0 ], len (text ))) # top_k, text_len
4446
4547 for token_idx , (start , end ) in enumerate (offsets ):
46- if start == end :
48+ if start == end : # skip special tokens
4749 continue
4850 attributions_per_char [:, start :end ] = attributions [:, token_idx ][:, None ]
4951
50- return attributions_per_char
52+ return np .exp (attributions_per_char ) / np .sum (
53+ np .exp (attributions_per_char ), axis = 1 , keepdims = True
54+ ) # softmax normalization
5155
5256
5357def map_attributions_to_word (attributions , word_ids ):
@@ -71,9 +75,14 @@ def map_attributions_to_word(attributions, word_ids):
7175 # Convert None to -1 for easier processing (PAD tokens)
7276 word_ids_int = np .array ([x if x is not None else - 1 for x in word_ids ], dtype = int )
7377
74- # Consider only tokens that belong to actual words (non-PAD)
78+ # Filter out PAD tokens from attributions and word_ids
79+ attributions = attributions [
80+ torch .arange (attributions .shape [0 ])[:, None ],
81+ torch .tensor (np .where (word_ids_int != - 1 )[0 ])[None , :],
82+ ]
83+ word_ids_int = word_ids_int [word_ids_int != - 1 ]
7584 unique_word_ids = np .unique (word_ids_int )
76- unique_word_ids = unique_word_ids [ unique_word_ids != - 1 ]
85+ num_unique_words = len ( unique_word_ids )
7786
7887 top_k = attributions .shape [0 ]
7988 attr_with_word_id = np .concat (
@@ -82,17 +91,25 @@ def map_attributions_to_word(attributions, word_ids):
8291 ) # top_k, seq_len, 2
8392 # last dim is 2: 0 is the attribution of the token, 1 is the word_id the token is associated to
8493
85- word_attributions = np .zeros ((top_k , len ( word_ids_int ) ))
94+ word_attributions = np .zeros ((top_k , num_unique_words ))
8695 for word_id in unique_word_ids :
8796 mask = attr_with_word_id [:, :, 1 ] == word_id # top_k, seq_len
8897 word_attributions [:, word_id ] = (attr_with_word_id [:, :, 0 ] * mask ).sum (
8998 axis = 1
9099 ) # zero-out non-matching tokens and sum attributions for all tokens belonging to the same word
91100
92- return word_attributions
101+ # assert word_attributions.sum(axis=1) == attributions.sum(axis=1), "Sum of word attributions per top_k must equal sum of token attributions per top_k."
102+ return np .exp (word_attributions ) / np .sum (
103+ np .exp (word_attributions ), axis = 1 , keepdims = True
104+ ) # softmax normalization
93105
94106
95- def plot_attributions_at_char (text , attributions_per_char , title = "Attributions" , figsize = (10 , 2 )):
107+ def plot_attributions_at_char (
108+ text : str ,
109+ attributions_per_char : np .ndarray ,
110+ figsize = (10 , 2 ),
111+ titles : Optional [List [str ]] = None ,
112+ ):
96113 """
97114 Plots character-level attributions as a heatmap.
98115 Args:
@@ -107,23 +124,26 @@ def plot_attributions_at_char(text, attributions_per_char, title="Attributions",
107124 raise ImportError (
108125 "matplotlib is required for plotting. Please install it to use this function."
109126 )
110-
111- plt .figure (figsize = figsize )
112- plt .imshow (attributions_per_char , aspect = "auto" , cmap = "viridis" )
113- plt .colorbar (label = "Attribution Score" )
114- plt .yticks (
115- ticks = np .arange (attributions_per_char .shape [0 ]),
116- labels = [f"Top { i + 1 } " for i in range (attributions_per_char .shape [0 ])],
117- )
118- plt .xticks (ticks = np .arange (len (text )), labels = list (text ), rotation = 90 )
119- plt .title (title )
120- plt .xlabel ("Characters in Text" )
121- plt .ylabel ("Top Predictions" )
122- plt .tight_layout ()
123- plt .show ()
124-
125-
126- def plot_attributions_at_word (text , attributions_per_word , title = "Attributions" , figsize = (10 , 2 )):
127+ top_k = attributions_per_char .shape [0 ]
128+
129+ all_plots = []
130+ for i in range (top_k ):
131+ fig , ax = plt .subplots (figsize = figsize )
132+ ax .bar (range (len (text )), attributions_per_char [i ])
133+ ax .set_xticks (np .arange (len (text )))
134+ ax .set_xticklabels (list (text ), rotation = 90 )
135+ title = titles [i ] if titles is not None else f"Attributions for Top { i + 1 } Prediction"
136+ ax .set_title (title )
137+ ax .set_xlabel ("Characters in Text" )
138+ ax .set_ylabel ("Top Predictions" )
139+ all_plots .append (fig )
140+
141+ return all_plots
142+
143+
144+ def plot_attributions_at_word (
145+ text , attributions_per_word , figsize = (10 , 2 ), titles : Optional [List [str ]] = None
146+ ):
127147 """
128148 Plots word-level attributions as a heatmap.
129149 Args:
@@ -140,16 +160,25 @@ def plot_attributions_at_word(text, attributions_per_word, title="Attributions",
140160 )
141161
142162 words = text .split ()
143- plt .figure (figsize = figsize )
144- plt .imshow (attributions_per_word , aspect = "auto" , cmap = "viridis" )
145- plt .colorbar (label = "Attribution Score" )
146- plt .yticks (
147- ticks = np .arange (attributions_per_word .shape [0 ]),
148- labels = [f"Top { i + 1 } " for i in range (attributions_per_word .shape [0 ])],
149- )
150- plt .xticks (ticks = np .arange (len (words )), labels = words , rotation = 90 )
151- plt .title (title )
152- plt .xlabel ("Words in Text" )
153- plt .ylabel ("Top Predictions" )
154- plt .tight_layout ()
163+ top_k = attributions_per_word .shape [0 ]
164+ all_plots = []
165+ for i in range (top_k ):
166+ fig , ax = plt .subplots (figsize = figsize )
167+ ax .bar (range (len (words )), attributions_per_word [i ])
168+ ax .set_xticks (np .arange (len (words )))
169+ ax .set_xticklabels (words , rotation = 90 )
170+ title = titles [i ] if titles is not None else f"Attributions for Top { i + 1 } Prediction"
171+ ax .set_title (title )
172+ ax .set_xlabel ("Words in Text" )
173+ ax .set_ylabel ("Attributions" )
174+ all_plots .append (fig )
175+
176+ return all_plots
177+
178+
def figshow(figure):
    """Display only *figure*, closing every other open pyplot figure.

    pyplot's ``show()`` displays ALL figures in its global registry, so the
    per-figure objects returned by the ``plot_attributions_at_*`` helpers
    cannot be shown one at a time directly; this closes the others first.
    https://stackoverflow.com/questions/53088212/create-multiple-figures-in-pyplot-but-only-show-one

    Args:
        figure: The matplotlib ``Figure`` to display.

    Raises:
        ImportError: If matplotlib is not installed.
    """
    # Import locally, matching the guarded-import pattern used by the
    # plotting functions in this module (matplotlib is an optional dep).
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError(
            "matplotlib is required for plotting. Please install it to use this function."
        )
    for num in plt.get_fignums():
        fig = plt.figure(num)  # look up each open figure once, not twice
        # Figures are compared by identity: keep only the requested one.
        if fig is not figure:
            plt.close(fig)
    plt.show()
0 commit comments