@@ -492,13 +492,15 @@ def predict(
492492 self ,
493493 X_test : np .ndarray ,
494494 top_k = 1 ,
495- explain = False ,
495+ explain_with_label_attention : bool = False ,
496+ explain_with_captum = False ,
496497 ):
497498 """
498499 Args:
499500 X_test (np.ndarray): input data to predict on, shape (N,d) where the first column is text and the rest are categorical variables
500501 top_k (int): for each sentence, return the top_k most likely predictions (default: 1)
501- explain (bool): launch gradient integration to have an explanation of the prediction (default: False)
502+ explain_with_label_attention (bool): if enabled, use attention matrix labels x tokens to have an explanation of the prediction (default: False)
503+ explain_with_captum (bool): if enabled, run Captum Layer Integrated Gradients to attribute the prediction to input tokens (default: False)
502504
503505 Returns: A dictionary containing the following fields:
504506 - predictions (torch.Tensor, shape (len(text), top_k)): A tensor containing the top_k most likely codes to the query.
@@ -507,6 +509,8 @@ def predict(
507- - attributions (torch.Tensor, shape (len(text), top_k, seq_len)): A tensor containing the attributions for each token in the text.
510+ - captum_attributions (torch.Tensor, shape (len(text), top_k, seq_len), or None): Captum token attributions, returned when explain_with_captum is True.
511+ - label_attention_attributions (torch.Tensor or None): labels x tokens attention matrix used as explanation, returned when explain_with_label_attention is True.
508512 """
509511
512+ explain = explain_with_label_attention or explain_with_captum
510513 if explain :
511514 return_offsets_mapping = True # to be passed to the tokenizer
512515 return_word_ids = True
@@ -515,13 +518,19 @@ def predict(
515518 "Explainability is not supported when the tokenizer outputs vectorized text directly. Please use a tokenizer that outputs token IDs."
516519 )
517520 else :
518- if not HAS_CAPTUM :
519- raise ImportError (
520- "Captum is not installed and is required for explainability. Run 'pip install/uv add torchFastText[explainability]'."
521- )
522- lig = LayerIntegratedGradients (
523- self .pytorch_model , self .pytorch_model .text_embedder .embedding_layer
524- ) # initialize a Captum layer gradient integrator
521+ if explain_with_captum :
522+ if not HAS_CAPTUM :
523+ raise ImportError (
524+ "Captum is not installed and is required for explainability. Run 'pip install/uv add torchFastText[explainability]'."
525+ )
526+ lig = LayerIntegratedGradients (
527+ self .pytorch_model , self .pytorch_model .text_embedder .embedding_layer
528+ ) # initialize a Captum layer gradient integrator
529+ if explain_with_label_attention :
530+ if not self .enable_label_attention :
531+ raise RuntimeError (
532+ "Label attention explainability is enabled, but the model was not configured with label attention. Please enable label attention in the model configuration during initialization and retrain."
533+ )
525534 else :
526535 return_offsets_mapping = False
527536 return_word_ids = False
@@ -553,9 +562,19 @@ def predict(
553562 else :
554563 categorical_vars = torch .empty ((encoded_text .shape [0 ], 0 ), dtype = torch .float32 )
555564
556- pred = self .pytorch_model (
557- encoded_text , attention_mask , categorical_vars
565+ model_output = self .pytorch_model (
566+ encoded_text ,
567+ attention_mask ,
568+ categorical_vars ,
569+ return_label_attention_matrix = explain_with_label_attention ,
558570 ) # forward pass, contains the prediction scores (len(text), num_classes)
571+ pred = (
572+ model_output ["logits" ] if explain_with_label_attention else model_output
573+ ) # (batch_size, num_classes)
574+
575+ label_attention_matrix = (
576+ model_output ["label_attention_matrix" ] if explain_with_label_attention else None
577+ )
559578
560579 label_scores = pred .detach ().cpu ().softmax (dim = 1 ) # convert to probabilities
561580
@@ -565,21 +584,28 @@ def predict(
565584 confidence = torch .round (label_scores_topk .values , decimals = 2 ) # and their scores
566585
567586 if explain :
568- all_attributions = []
569- for k in range (top_k ):
570- attributions = lig .attribute (
571- (encoded_text , attention_mask , categorical_vars ),
572- target = torch .Tensor (predictions [:, k ]).long (),
573- ) # (batch_size, seq_len)
574- attributions = attributions .sum (dim = - 1 )
575- all_attributions .append (attributions .detach ().cpu ())
576-
577- all_attributions = torch .stack (all_attributions , dim = 1 ) # (batch_size, top_k, seq_len)
587+ if explain_with_captum :
588+ # Captum explanations
589+ captum_attributions = []
590+ for k in range (top_k ):
591+ attributions = lig .attribute (
592+ (encoded_text , attention_mask , categorical_vars ),
593+ target = torch .Tensor (predictions [:, k ]).long (),
594+ ) # (batch_size, seq_len)
595+ attributions = attributions .sum (dim = - 1 )
596+ captum_attributions .append (attributions .detach ().cpu ())
597+
598+ captum_attributions = torch .stack (
599+ captum_attributions , dim = 1
600+ ) # (batch_size, top_k, seq_len)
601+ else :
602+ captum_attributions = None
578603
579604 return {
580605 "prediction" : predictions ,
581606 "confidence" : confidence ,
582- "attributions" : all_attributions ,
607+ "captum_attributions" : captum_attributions ,
608+ "label_attention_attributions" : label_attention_matrix ,
583609 "offset_mapping" : tokenize_output .offset_mapping ,
584610 "word_ids" : tokenize_output .word_ids ,
585611 }