
Commit 269c76a

fix!(explainability): remove nan words and fix plotting

Parent: ea26799

1 file changed: torchTextClassifiers/utilities/plot_explainability.py

Lines changed: 66 additions & 37 deletions
@@ -1,3 +1,5 @@
+from typing import List, Optional
+
 import numpy as np
 import torch
 
@@ -40,14 +42,16 @@ def map_attributions_to_char(attributions, offsets, text):
     if attributions.ndim == 1:
         attributions = attributions[None, :]
 
-    attributions_per_char = np.empty((attributions.shape[0], len(text)))  # top_k, text_len
+    attributions_per_char = np.zeros((attributions.shape[0], len(text)))  # top_k, text_len
 
     for token_idx, (start, end) in enumerate(offsets):
-        if start == end:
+        if start == end:  # skip special tokens
             continue
         attributions_per_char[:, start:end] = attributions[:, token_idx][:, None]
 
-    return attributions_per_char
+    return np.exp(attributions_per_char) / np.sum(
+        np.exp(attributions_per_char), axis=1, keepdims=True
+    )  # softmax normalization
 
 
 def map_attributions_to_word(attributions, word_ids):
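
Two fixes are visible in this hunk: `np.empty` left characters not covered by any token offset holding uninitialized memory (the source of the "nan words" in the commit title), which `np.zeros` avoids, and the raw scores are now passed through a row-wise softmax. A minimal sketch of that normalization on a toy array (values are hypothetical):

    import numpy as np

    # Toy (top_k=2, text_len=3) attribution matrix; values are made up.
    attr = np.array([[0.5, -1.0, 2.0],
                     [0.0,  0.0, 0.0]])

    # Row-wise softmax, exactly as in the new return statement.
    norm = np.exp(attr) / np.sum(np.exp(attr), axis=1, keepdims=True)

    print(norm.sum(axis=1))  # [1. 1.] -- each top_k row now sums to 1
    print(norm[1])           # [0.333 0.333 0.333] -- untouched zeros become a uniform baseline

A side effect worth noting: characters never covered by an offset stay at 0, and exp(0) = 1, so they receive a uniform baseline share instead of garbage.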
@@ -71,9 +75,14 @@ def map_attributions_to_word(attributions, word_ids):
     # Convert None to -1 for easier processing (PAD tokens)
     word_ids_int = np.array([x if x is not None else -1 for x in word_ids], dtype=int)
 
-    # Consider only tokens that belong to actual words (non-PAD)
+    # Filter out PAD tokens from attributions and word_ids
+    attributions = attributions[
+        torch.arange(attributions.shape[0])[:, None],
+        torch.tensor(np.where(word_ids_int != -1)[0])[None, :],
+    ]
+    word_ids_int = word_ids_int[word_ids_int != -1]
     unique_word_ids = np.unique(word_ids_int)
-    unique_word_ids = unique_word_ids[unique_word_ids != -1]
+    num_unique_words = len(unique_word_ids)
 
     top_k = attributions.shape[0]
     attr_with_word_id = np.concat(
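
The new advanced-indexing expression drops the PAD/special-token columns from `attributions` before aggregation: a `(top_k, 1)` row index broadcast against a `(1, n_kept)` column index selects the kept columns for every row. A small sketch of the same indexing pattern, with hypothetical toy values:

    import numpy as np
    import torch

    attributions = torch.tensor([[0.1, 0.2, 0.3, 0.4]])  # top_k=1, seq_len=4 (toy values)
    word_ids_int = np.array([-1, 0, 0, -1])              # -1 marks PAD/special tokens

    keep = torch.tensor(np.where(word_ids_int != -1)[0])  # columns to keep: [1, 2]
    filtered = attributions[
        torch.arange(attributions.shape[0])[:, None],  # row index, shape (top_k, 1)
        keep[None, :],                                 # column index, shape (1, n_kept)
    ]
    print(filtered)  # tensor([[0.2000, 0.3000]]) -- PAD columns removed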
@@ -82,17 +91,25 @@ def map_attributions_to_word(attributions, word_ids):
     )  # top_k, seq_len, 2
     # last dim is 2: 0 is the attribution of the token, 1 is the word_id the token is associated to
 
-    word_attributions = np.zeros((top_k, len(word_ids_int)))
+    word_attributions = np.zeros((top_k, num_unique_words))
     for word_id in unique_word_ids:
         mask = attr_with_word_id[:, :, 1] == word_id  # top_k, seq_len
         word_attributions[:, word_id] = (attr_with_word_id[:, :, 0] * mask).sum(
             axis=1
         )  # zero-out non-matching tokens and sum attributions for all tokens belonging to the same word
 
-    return word_attributions
+    # assert word_attributions.sum(axis=1) == attributions.sum(axis=1), "Sum of word attributions per top_k must equal sum of token attributions per top_k."
+    return np.exp(word_attributions) / np.sum(
+        np.exp(word_attributions), axis=1, keepdims=True
+    )  # softmax normalization
 
 
-def plot_attributions_at_char(text, attributions_per_char, title="Attributions", figsize=(10, 2)):
+def plot_attributions_at_char(
+    text: str,
+    attributions_per_char: np.ndarray,
+    figsize=(10, 2),
+    titles: Optional[List[str]] = None,
+):
     """
     Plots character-level attributions as a heatmap.
     Args:
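
The aggregation loop zeroes out tokens belonging to other words and sums the rest, so each column of `word_attributions` holds the total attribution of one word; note that it indexes columns with `word_id` directly, which assumes the surviving word ids run 0..num_unique_words-1. A toy walk-through with hypothetical values:

    import numpy as np

    # top_k=1, seq_len=3; last dim pairs (attribution, word_id). Toy values.
    attr_with_word_id = np.array([[[0.1, 0.0],
                                   [0.2, 0.0],
                                   [0.3, 1.0]]])

    word_attributions = np.zeros((1, 2))
    for word_id in (0, 1):
        mask = attr_with_word_id[:, :, 1] == word_id  # select this word's tokens
        word_attributions[:, word_id] = (attr_with_word_id[:, :, 0] * mask).sum(axis=1)

    print(word_attributions)  # [[0.3 0.3]] -- word 0 = 0.1 + 0.2, word 1 = 0.3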
@@ -107,23 +124,26 @@ def plot_attributions_at_char(text, attributions_per_char, title="Attributions",
         raise ImportError(
             "matplotlib is required for plotting. Please install it to use this function."
         )
-
-    plt.figure(figsize=figsize)
-    plt.imshow(attributions_per_char, aspect="auto", cmap="viridis")
-    plt.colorbar(label="Attribution Score")
-    plt.yticks(
-        ticks=np.arange(attributions_per_char.shape[0]),
-        labels=[f"Top {i+1}" for i in range(attributions_per_char.shape[0])],
-    )
-    plt.xticks(ticks=np.arange(len(text)), labels=list(text), rotation=90)
-    plt.title(title)
-    plt.xlabel("Characters in Text")
-    plt.ylabel("Top Predictions")
-    plt.tight_layout()
-    plt.show()
-
-
-def plot_attributions_at_word(text, attributions_per_word, title="Attributions", figsize=(10, 2)):
+    top_k = attributions_per_char.shape[0]
+
+    all_plots = []
+    for i in range(top_k):
+        fig, ax = plt.subplots(figsize=figsize)
+        ax.bar(range(len(text)), attributions_per_char[i])
+        ax.set_xticks(np.arange(len(text)))
+        ax.set_xticklabels(list(text), rotation=90)
+        title = titles[i] if titles is not None else f"Attributions for Top {i+1} Prediction"
+        ax.set_title(title)
+        ax.set_xlabel("Characters in Text")
+        ax.set_ylabel("Top Predictions")
+        all_plots.append(fig)
+
+    return all_plots
+
+
+def plot_attributions_at_word(
+    text, attributions_per_word, figsize=(10, 2), titles: Optional[List[str]] = None
+):
     """
     Plots word-level attributions as a heatmap.
    Args:
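
The single shared heatmap is replaced by one bar chart per top-k prediction, collected into a list instead of being shown eagerly; a `titles` list replaces the single `title` argument, with a per-figure default when it is None. A minimal hypothetical call (`char_attributions` is assumed precomputed, shape `(top_k, len(text))`, e.g. the output of `map_attributions_to_char`):

    figs = plot_attributions_at_char("fox", char_attributions)
    for i, fig in enumerate(figs):
        fig.savefig(f"char_attr_top{i + 1}.png")  # e.g. persist each per-prediction figure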
@@ -140,16 +160,25 @@ def plot_attributions_at_word(text, attributions_per_word, title="Attributions",
         )
 
     words = text.split()
-    plt.figure(figsize=figsize)
-    plt.imshow(attributions_per_word, aspect="auto", cmap="viridis")
-    plt.colorbar(label="Attribution Score")
-    plt.yticks(
-        ticks=np.arange(attributions_per_word.shape[0]),
-        labels=[f"Top {i+1}" for i in range(attributions_per_word.shape[0])],
-    )
-    plt.xticks(ticks=np.arange(len(words)), labels=words, rotation=90)
-    plt.title(title)
-    plt.xlabel("Words in Text")
-    plt.ylabel("Top Predictions")
-    plt.tight_layout()
+    top_k = attributions_per_word.shape[0]
+    all_plots = []
+    for i in range(top_k):
+        fig, ax = plt.subplots(figsize=figsize)
+        ax.bar(range(len(words)), attributions_per_word[i])
+        ax.set_xticks(np.arange(len(words)))
+        ax.set_xticklabels(words, rotation=90)
+        title = titles[i] if titles is not None else f"Attributions for Top {i+1} Prediction"
+        ax.set_title(title)
+        ax.set_xlabel("Words in Text")
+        ax.set_ylabel("Attributions")
+        all_plots.append(fig)
+
+    return all_plots
+
+
+def figshow(figure):
+    # https://stackoverflow.com/questions/53088212/create-multiple-figures-in-pyplot-but-only-show-one
+    for i in plt.get_fignums():
+        if figure != plt.figure(i):
+            plt.close(plt.figure(i))
     plt.show()
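
With this change the plotting helpers no longer call `plt.show()` themselves; they return one figure per top-k prediction, and the new `figshow` helper displays a chosen figure while closing all the others. A hypothetical end-to-end usage sketch (`word_attributions` is assumed to come from `map_attributions_to_word`, shape `(top_k, n_words)`, here with top_k=2):

    figs = plot_attributions_at_word(
        "the quick brown fox",
        word_attributions,
        titles=["Top 1: animal", "Top 2: color"],  # optional per-figure titles
    )
    figshow(figs[0])  # display only the first prediction's figure; the rest are closed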
