Skip to content

Commit a2fe33e

Browse files
feat(explainability): add new explainability pipeline with label attention
- given a parameter, retrieve the label-attention matrix from the forward pass
- keep compatibility with Captum attributions
- update tests accordingly
1 parent 0a1880b commit a2fe33e

4 files changed

Lines changed: 77 additions & 32 deletions

File tree

tests/test_pipeline.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,13 +189,19 @@ def run_full_pipeline(
189189

190190
# Predict with explanations
191191
top_k = 5
192-
predictions = ttc.predict(X, top_k=top_k, explain=True)
192+
193+
predictions = ttc.predict(
194+
X,
195+
top_k=top_k,
196+
explain_with_label_attention=label_attention_enabled,
197+
explain_with_captum=True,
198+
)
193199

194200
# Test explainability functions
195201
text_idx = 0
196202
text = sample_text_data[text_idx]
197203
offsets = predictions["offset_mapping"][text_idx]
198-
attributions = predictions["attributions"][text_idx]
204+
attributions = predictions["captum_attributions"][text_idx]
199205
word_ids = predictions["word_ids"][text_idx]
200206

201207
words, word_attributions = map_attributions_to_word(attributions, text, word_ids, offsets)

torchTextClassifiers/model/components/text_embedder.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -170,13 +170,10 @@ def forward(
170170
return_label_attention_matrix=return_label_attention_matrix,
171171
).values()
172172

173-
if return_label_attention_matrix:
174-
return (
175-
text_embedding,
176-
label_attention_matrix,
177-
) # label_attention_matrix is None if label attention is disabled
178-
else:
179-
return text_embedding
173+
return {
174+
"sentence_embedding": text_embedding,
175+
"label_attention_matrix": label_attention_matrix,
176+
}
180177

181178
def _get_sentence_embedding(
182179
self,
@@ -304,6 +301,9 @@ def forward(self, token_embeddings, compute_attention_matrix: Optional[bool] = F
304301
305302
"""
306303
B, T, C = token_embeddings.size()
304+
if isinstance(compute_attention_matrix, torch.Tensor):
305+
compute_attention_matrix = compute_attention_matrix[0].item()
306+
compute_attention_matrix = bool(compute_attention_matrix)
307307

308308
# 1. Create label indices [0, 1, ..., C-1] for the whole batch
309309
label_indices = torch.arange(self.num_classes).expand(B, -1)

torchTextClassifiers/model/model.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def forward(
118118
input_ids: Annotated[torch.Tensor, "batch seq_len"],
119119
attention_mask: Annotated[torch.Tensor, "batch seq_len"],
120120
categorical_vars: Annotated[torch.Tensor, "batch num_cats"],
121+
return_label_attention_matrix: bool = False,
121122
**kwargs,
122123
) -> torch.Tensor:
123124
"""
@@ -136,7 +137,16 @@ def forward(
136137
if self.text_embedder is None:
137138
x_text = encoded_text.float()
138139
else:
139-
x_text = self.text_embedder(input_ids=encoded_text, attention_mask=attention_mask)
140+
text_embed_output = self.text_embedder(
141+
input_ids=encoded_text,
142+
attention_mask=attention_mask,
143+
return_label_attention_matrix=return_label_attention_matrix,
144+
)
145+
x_text = text_embed_output["sentence_embedding"]
146+
if isinstance(return_label_attention_matrix, torch.Tensor):
147+
return_label_attention_matrix = return_label_attention_matrix[0].item()
148+
if return_label_attention_matrix:
149+
label_attention_matrix = text_embed_output["label_attention_matrix"]
140150

141151
if self.categorical_variable_net:
142152
x_cat = self.categorical_variable_net(categorical_vars)
@@ -166,4 +176,7 @@ def forward(
166176

167177
logits = self.classification_head(norm(x_combined)).squeeze(-1)
168178

179+
if return_label_attention_matrix:
180+
return {"logits": logits, "label_attention_matrix": label_attention_matrix}
181+
169182
return logits

torchTextClassifiers/torchTextClassifiers.py

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -492,13 +492,15 @@ def predict(
492492
self,
493493
X_test: np.ndarray,
494494
top_k=1,
495-
explain=False,
495+
explain_with_label_attention: bool = False,
496+
explain_with_captum=False,
496497
):
497498
"""
498499
Args:
499500
X_test (np.ndarray): input data to predict on, shape (N,d) where the first column is text and the rest are categorical variables
500501
top_k (int): for each sentence, return the top_k most likely predictions (default: 1)
501-
explain (bool): launch gradient integration to have an explanation of the prediction (default: False)
502+
explain_with_label_attention (bool): if enabled, use attention matrix labels x tokens to have an explanation of the prediction (default: False)
503+
explain_with_captum (bool): launch gradient integration with Captum for explanation (default: False)
502504
503505
Returns: A dictionary containing the following fields:
504506
- predictions (torch.Tensor, shape (len(text), top_k)): A tensor containing the top_k most likely codes to the query.
@@ -507,6 +509,7 @@ def predict(
507509
- attributions (torch.Tensor, shape (len(text), top_k, seq_len)): A tensor containing the attributions for each token in the text.
508510
"""
509511

512+
explain = explain_with_label_attention or explain_with_captum
510513
if explain:
511514
return_offsets_mapping = True # to be passed to the tokenizer
512515
return_word_ids = True
@@ -515,13 +518,19 @@ def predict(
515518
"Explainability is not supported when the tokenizer outputs vectorized text directly. Please use a tokenizer that outputs token IDs."
516519
)
517520
else:
518-
if not HAS_CAPTUM:
519-
raise ImportError(
520-
"Captum is not installed and is required for explainability. Run 'pip install/uv add torchFastText[explainability]'."
521-
)
522-
lig = LayerIntegratedGradients(
523-
self.pytorch_model, self.pytorch_model.text_embedder.embedding_layer
524-
) # initialize a Captum layer gradient integrator
521+
if explain_with_captum:
522+
if not HAS_CAPTUM:
523+
raise ImportError(
524+
"Captum is not installed and is required for explainability. Run 'pip install/uv add torchFastText[explainability]'."
525+
)
526+
lig = LayerIntegratedGradients(
527+
self.pytorch_model, self.pytorch_model.text_embedder.embedding_layer
528+
) # initialize a Captum layer gradient integrator
529+
if explain_with_label_attention:
530+
if not self.enable_label_attention:
531+
raise RuntimeError(
532+
"Label attention explainability is enabled, but the model was not configured with label attention. Please enable label attention in the model configuration during initialization and retrain."
533+
)
525534
else:
526535
return_offsets_mapping = False
527536
return_word_ids = False
@@ -553,9 +562,19 @@ def predict(
553562
else:
554563
categorical_vars = torch.empty((encoded_text.shape[0], 0), dtype=torch.float32)
555564

556-
pred = self.pytorch_model(
557-
encoded_text, attention_mask, categorical_vars
565+
model_output = self.pytorch_model(
566+
encoded_text,
567+
attention_mask,
568+
categorical_vars,
569+
return_label_attention_matrix=explain_with_label_attention,
558570
) # forward pass, contains the prediction scores (len(text), num_classes)
571+
pred = (
572+
model_output["logits"] if explain_with_label_attention else model_output
573+
) # (batch_size, num_classes)
574+
575+
label_attention_matrix = (
576+
model_output["label_attention_matrix"] if explain_with_label_attention else None
577+
)
559578

560579
label_scores = pred.detach().cpu().softmax(dim=1) # convert to probabilities
561580

@@ -565,21 +584,28 @@ def predict(
565584
confidence = torch.round(label_scores_topk.values, decimals=2) # and their scores
566585

567586
if explain:
568-
all_attributions = []
569-
for k in range(top_k):
570-
attributions = lig.attribute(
571-
(encoded_text, attention_mask, categorical_vars),
572-
target=torch.Tensor(predictions[:, k]).long(),
573-
) # (batch_size, seq_len)
574-
attributions = attributions.sum(dim=-1)
575-
all_attributions.append(attributions.detach().cpu())
576-
577-
all_attributions = torch.stack(all_attributions, dim=1) # (batch_size, top_k, seq_len)
587+
if explain_with_captum:
588+
# Captum explanations
589+
captum_attributions = []
590+
for k in range(top_k):
591+
attributions = lig.attribute(
592+
(encoded_text, attention_mask, categorical_vars),
593+
target=torch.Tensor(predictions[:, k]).long(),
594+
) # (batch_size, seq_len)
595+
attributions = attributions.sum(dim=-1)
596+
captum_attributions.append(attributions.detach().cpu())
597+
598+
captum_attributions = torch.stack(
599+
captum_attributions, dim=1
600+
) # (batch_size, top_k, seq_len)
601+
else:
602+
captum_attributions = None
578603

579604
return {
580605
"prediction": predictions,
581606
"confidence": confidence,
582-
"attributions": all_attributions,
607+
"captum_attributions": captum_attributions,
608+
"label_attention_attributions": label_attention_matrix,
583609
"offset_mapping": tokenize_output.offset_mapping,
584610
"word_ids": tokenize_output.word_ids,
585611
}

0 commit comments

Comments
 (0)