diff --git a/CHANGELOG.md b/CHANGELOG.md index 76dd3ef40883..9aafc60efac7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). ### Added +- Adapted `fidelity` and `unfaithfulness` to support `HeteroExplanation`, propagating per-type node and edge masks (and per-node-type top-k feature selection) ([#10677](https://github.com/pyg-team/pytorch_geometric/pull/10677)) + ### Changed - Dropped support for TorchScript in `GATConv` and `GATv2Conv` for correctness ([#10596](https://github.com/pyg-team/pytorch_geometric/pull/10596)) diff --git a/test/explain/metric/test_faithfulness.py b/test/explain/metric/test_faithfulness.py index 627a6093076e..f406ab7ab816 100644 --- a/test/explain/metric/test_faithfulness.py +++ b/test/explain/metric/test_faithfulness.py @@ -1,12 +1,14 @@ import pytest import torch +from torch_geometric.data import HeteroData from torch_geometric.explain import ( DummyExplainer, Explainer, ModelConfig, unfaithfulness, ) +from torch_geometric.testing import get_random_edge_index class DummyModel(torch.nn.Module): @@ -22,6 +24,20 @@ def forward(self, x, edge_index): return x +class HeteroDummyModel(torch.nn.Module): + def __init__(self, model_config: ModelConfig): + super().__init__() + self.model_config = model_config + + def forward(self, x_dict, edge_index_dict, **kwargs): + x = x_dict['paper'] + if self.model_config.return_type.value == 'probs': + x = x.softmax(dim=-1) + elif self.model_config.return_type.value == 'log_probs': + x = x.log_softmax(dim=-1) + return x + + @pytest.mark.parametrize('top_k', [None, 2]) @pytest.mark.parametrize('explanation_type', ['model', 'phenomenon']) @pytest.mark.parametrize('node_mask_type', ['common_attributes', 'attributes']) @@ -57,3 +73,47 @@ def test_unfaithfulness(top_k, explanation_type, node_mask_type, return_type): metric = unfaithfulness(explainer, explanation, top_k) assert metric >= 0. and metric <= 1. + + +@pytest.mark.parametrize('top_k', [None, 2]) +@pytest.mark.parametrize('explanation_type', ['model', 'phenomenon']) +@pytest.mark.parametrize('node_mask_type', ['common_attributes', 'attributes']) +@pytest.mark.parametrize('return_type', ['raw', 'probs', 'log_probs']) +def test_unfaithfulness_hetero(top_k, explanation_type, node_mask_type, + return_type): + data = HeteroData() + data['paper'].x = torch.randn(8, 4) + data['author'].x = torch.randn(10, 4) + data['paper', 'paper'].edge_index = get_random_edge_index(8, 8, 10) + data['paper', 'author'].edge_index = get_random_edge_index(8, 10, 10) + data['author', 'paper'].edge_index = get_random_edge_index(10, 8, 10) + + model_config = ModelConfig( + mode='multiclass_classification', + task_level='node', + return_type=return_type, + ) + + explainer = Explainer( + HeteroDummyModel(model_config), + algorithm=DummyExplainer(), + explanation_type=explanation_type, + node_mask_type=node_mask_type, + edge_mask_type='object', + model_config=model_config, + ) + + target = None + if explanation_type == 'phenomenon': + target = torch.randint(0, data['paper'].x.size(1), + (data['paper'].x.size(0), )) + + explanation = explainer( + data.x_dict, + data.edge_index_dict, + target=target, + index=torch.arange(4), + ) + + metric = unfaithfulness(explainer, explanation, top_k) + assert metric >= 0. and metric <= 1. diff --git a/test/explain/metric/test_fidelity.py b/test/explain/metric/test_fidelity.py index 1e5c4bd0171e..19df6466c942 100644 --- a/test/explain/metric/test_fidelity.py +++ b/test/explain/metric/test_fidelity.py @@ -1,6 +1,7 @@ import pytest import torch +from torch_geometric.data import HeteroData from torch_geometric.explain import ( DummyExplainer, Explainer, @@ -8,6 +9,7 @@ fidelity, fidelity_curve_auc, ) +from torch_geometric.testing import get_random_edge_index class DummyModel(torch.nn.Module): @@ -15,6 +17,11 @@ def forward(self, x, edge_index): return x +class HeteroDummyModel(torch.nn.Module): + def forward(self, x_dict, edge_index_dict, **kwargs): + return x_dict['paper'] + + @pytest.mark.parametrize('explanation_type', ['model', 'phenomenon']) def test_fidelity(explanation_type): x = torch.randn(8, 4) @@ -47,6 +54,47 @@ def test_fidelity(explanation_type): assert pos_fidelity == 0.0 and neg_fidelity == 0.0 +@pytest.mark.parametrize('explanation_type', ['model', 'phenomenon']) +def test_fidelity_hetero(explanation_type): + data = HeteroData() + data['paper'].x = torch.randn(8, 4) + data['author'].x = torch.randn(10, 4) + data['paper', 'paper'].edge_index = get_random_edge_index(8, 8, 10) + data['paper', 'author'].edge_index = get_random_edge_index(8, 10, 10) + data['author', 'paper'].edge_index = get_random_edge_index(10, 8, 10) + + explainer = Explainer( + HeteroDummyModel(), + algorithm=DummyExplainer(), + explanation_type=explanation_type, + node_mask_type='object', + edge_mask_type='object', + model_config=dict( + mode='multiclass_classification', + return_type='raw', + task_level='node', + ), + ) + + target = None + if explanation_type == 'phenomenon': + target = torch.randint(0, data['paper'].x.size(1), + (data['paper'].x.size(0), )) + + explanation = explainer( + data.x_dict, + data.edge_index_dict, + target=target, + index=torch.arange(4), + ) + + pos_fidelity, neg_fidelity = fidelity(explainer, explanation) + # `HeteroDummyModel` returns `x_dict['paper']` and the explainer applies + # object-level masks which uniformly scale each node's feature row, so the + # predicted class is preserved and both fidelity scores collapse to zero. + assert pos_fidelity == 0.0 and neg_fidelity == 0.0 + + def test_characterization_score(): out = characterization_score( pos_fidelity=torch.tensor([1.0, 0.6, 0.5, 1.0]), diff --git a/torch_geometric/explain/metric/faithfulness.py b/torch_geometric/explain/metric/faithfulness.py index 4ed090f92995..3914b09d8164 100644 --- a/torch_geometric/explain/metric/faithfulness.py +++ b/torch_geometric/explain/metric/faithfulness.py @@ -1,15 +1,16 @@ -from typing import Optional +from typing import Optional, Union import torch import torch.nn.functional as F -from torch_geometric.explain import Explainer, Explanation +from torch_geometric.explain import Explainer, Explanation, HeteroExplanation from torch_geometric.explain.config import MaskType, ModelMode, ModelReturnType +from torch_geometric.explain.metric.fidelity import _hetero_model_kwargs def unfaithfulness( explainer: Explainer, - explanation: Explanation, + explanation: Union[Explanation, HeteroExplanation], top_k: Optional[int] = None, ) -> float: r"""Evaluates how faithful an :class:`~torch_geometric.explain.Explanation` @@ -43,6 +44,9 @@ def unfaithfulness( raise ValueError("Cannot apply top-k feature selection based on a " "node mask of type 'object'") + if isinstance(explanation, HeteroExplanation): + return _hetero_unfaithfulness(explainer, explanation, top_k) + node_mask = explanation.get('node_mask') edge_mask = explanation.get('edge_mask') x, edge_index = explanation.x, explanation.edge_index @@ -71,3 +75,44 @@ def unfaithfulness( kl_div = F.kl_div(y.log(), y_hat, reduction='batchmean') return 1 - float(torch.exp(-kl_div)) + + +def _hetero_unfaithfulness( + explainer: Explainer, + explanation: HeteroExplanation, + top_k: Optional[int] = None, +) -> float: + node_mask = explanation.collect('node_mask', allow_empty=True) or None + edge_mask = explanation.collect('edge_mask', allow_empty=True) or None + x = explanation.x_dict + edge_index = explanation.edge_index_dict + kwargs = _hetero_model_kwargs(explanation) + + y = explanation.get('prediction') + if y is None: # == ExplanationType.phenomenon + y = explainer.get_prediction(x, edge_index, **kwargs) + + if node_mask is not None and top_k is not None: + new_node_mask = {} + for node_type, mask in node_mask.items(): + feat_importance = mask.sum(dim=0) + k = min(top_k, feat_importance.size(0)) + _, top_k_index = feat_importance.topk(k) + new_mask = torch.zeros_like(mask) + new_mask[:, top_k_index] = 1.0 + new_node_mask[node_type] = new_mask + node_mask = new_node_mask + + y_hat = explainer.get_masked_prediction(x, edge_index, node_mask, + edge_mask, **kwargs) + + if explanation.get('index') is not None: + y, y_hat = y[explanation.index], y_hat[explanation.index] + + if explainer.model_config.return_type == ModelReturnType.raw: + y, y_hat = y.softmax(dim=-1), y_hat.softmax(dim=-1) + elif explainer.model_config.return_type == ModelReturnType.log_probs: + y, y_hat = y.exp(), y_hat.exp() + + kl_div = F.kl_div(y.log(), y_hat, reduction='batchmean') + return 1 - float(torch.exp(-kl_div)) diff --git a/torch_geometric/explain/metric/fidelity.py b/torch_geometric/explain/metric/fidelity.py index bf9cf2f14b00..7a073bbfe4c8 100644 --- a/torch_geometric/explain/metric/fidelity.py +++ b/torch_geometric/explain/metric/fidelity.py @@ -1,15 +1,28 @@ -from typing import Tuple +from typing import Any, Dict, Tuple, Union import torch from torch import Tensor -from torch_geometric.explain import Explainer, Explanation +from torch_geometric.explain import Explainer, Explanation, HeteroExplanation from torch_geometric.explain.config import ExplanationType, ModelMode +def _hetero_model_kwargs(explanation: HeteroExplanation) -> Dict[str, Any]: + # Reconstructs the keyword arguments originally passed to the model. + # Falls back to an empty dict when `_model_args` was not populated + # (e.g. by an older `Explainer`); see PR #10672. + kwargs: Dict[str, Any] = {} + for key in getattr(explanation, '_model_args', []): + if key.endswith('_dict'): + kwargs[key] = explanation.collect(key[:-5], allow_empty=True) + else: + kwargs[key] = explanation[key] + return kwargs + + def fidelity( explainer: Explainer, - explanation: Explanation, + explanation: Union[Explanation, HeteroExplanation], ) -> Tuple[float, float]: r"""Evaluates the fidelity of an :class:`~torch_geometric.explain.Explainer` given an @@ -50,6 +63,9 @@ def fidelity( if explainer.model_config.mode == ModelMode.regression: raise ValueError("Fidelity not defined for 'regression' models") + if isinstance(explanation, HeteroExplanation): + return _hetero_fidelity(explainer, explanation) + node_mask = explanation.get('node_mask') edge_mask = explanation.get('edge_mask') kwargs = {key: explanation[key] for key in explanation._model_args} @@ -100,6 +116,66 @@ def fidelity( return float(pos_fidelity), float(neg_fidelity) +def _hetero_fidelity( + explainer: Explainer, + explanation: HeteroExplanation, +) -> Tuple[float, float]: + node_mask = explanation.collect('node_mask', allow_empty=True) or None + edge_mask = explanation.collect('edge_mask', allow_empty=True) or None + kwargs = _hetero_model_kwargs(explanation) + + y = explanation.target + if explainer.explanation_type == ExplanationType.phenomenon: + y_hat = explainer.get_prediction( + explanation.x_dict, + explanation.edge_index_dict, + **kwargs, + ) + y_hat = explainer.get_target(y_hat) + + explain_y_hat = explainer.get_masked_prediction( + explanation.x_dict, + explanation.edge_index_dict, + node_mask, + edge_mask, + **kwargs, + ) + explain_y_hat = explainer.get_target(explain_y_hat) + + complement_y_hat = explainer.get_masked_prediction( + explanation.x_dict, + explanation.edge_index_dict, + { + k: 1. - v + for k, v in node_mask.items() + } if node_mask is not None else None, + { + k: 1. - v + for k, v in edge_mask.items() + } if edge_mask is not None else None, + **kwargs, + ) + complement_y_hat = explainer.get_target(complement_y_hat) + + if explanation.get('index') is not None: + y = y[explanation.index] + if explainer.explanation_type == ExplanationType.phenomenon: + y_hat = y_hat[explanation.index] + explain_y_hat = explain_y_hat[explanation.index] + complement_y_hat = complement_y_hat[explanation.index] + + if explainer.explanation_type == ExplanationType.model: + pos_fidelity = 1. - (complement_y_hat == y).float().mean() + neg_fidelity = 1. - (explain_y_hat == y).float().mean() + else: + pos_fidelity = ((y_hat == y).float() - + (complement_y_hat == y).float()).abs().mean() + neg_fidelity = ((y_hat == y).float() - + (explain_y_hat == y).float()).abs().mean() + + return float(pos_fidelity), float(neg_fidelity) + + def characterization_score( pos_fidelity: Tensor, neg_fidelity: Tensor,