Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).

### Added

- Adapted `fidelity` and `unfaithfulness` to support `HeteroExplanation`, propagating per-type node and edge masks (and per-node-type top-k feature selection) ([#10677](https://github.com/pyg-team/pytorch_geometric/pull/10677))

### Changed

- Dropped support for TorchScript in `GATConv` and `GATv2Conv` for correctness ([#10596](https://github.com/pyg-team/pytorch_geometric/pull/10596))
Expand Down
60 changes: 60 additions & 0 deletions test/explain/metric/test_faithfulness.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import pytest
import torch

from torch_geometric.data import HeteroData
from torch_geometric.explain import (
DummyExplainer,
Explainer,
ModelConfig,
unfaithfulness,
)
from torch_geometric.testing import get_random_edge_index


class DummyModel(torch.nn.Module):
Expand All @@ -22,6 +24,20 @@ def forward(self, x, edge_index):
return x


class HeteroDummyModel(torch.nn.Module):
def __init__(self, model_config: ModelConfig):
super().__init__()
self.model_config = model_config

def forward(self, x_dict, edge_index_dict, **kwargs):
x = x_dict['paper']
if self.model_config.return_type.value == 'probs':
x = x.softmax(dim=-1)
elif self.model_config.return_type.value == 'log_probs':
x = x.log_softmax(dim=-1)
return x


@pytest.mark.parametrize('top_k', [None, 2])
@pytest.mark.parametrize('explanation_type', ['model', 'phenomenon'])
@pytest.mark.parametrize('node_mask_type', ['common_attributes', 'attributes'])
Expand Down Expand Up @@ -57,3 +73,47 @@ def test_unfaithfulness(top_k, explanation_type, node_mask_type, return_type):

metric = unfaithfulness(explainer, explanation, top_k)
assert metric >= 0. and metric <= 1.


@pytest.mark.parametrize('top_k', [None, 2])
@pytest.mark.parametrize('explanation_type', ['model', 'phenomenon'])
@pytest.mark.parametrize('node_mask_type', ['common_attributes', 'attributes'])
@pytest.mark.parametrize('return_type', ['raw', 'probs', 'log_probs'])
def test_unfaithfulness_hetero(top_k, explanation_type, node_mask_type,
return_type):
data = HeteroData()
data['paper'].x = torch.randn(8, 4)
data['author'].x = torch.randn(10, 4)
data['paper', 'paper'].edge_index = get_random_edge_index(8, 8, 10)
data['paper', 'author'].edge_index = get_random_edge_index(8, 10, 10)
data['author', 'paper'].edge_index = get_random_edge_index(10, 8, 10)

model_config = ModelConfig(
mode='multiclass_classification',
task_level='node',
return_type=return_type,
)

explainer = Explainer(
HeteroDummyModel(model_config),
algorithm=DummyExplainer(),
explanation_type=explanation_type,
node_mask_type=node_mask_type,
edge_mask_type='object',
model_config=model_config,
)

target = None
if explanation_type == 'phenomenon':
target = torch.randint(0, data['paper'].x.size(1),
(data['paper'].x.size(0), ))

explanation = explainer(
data.x_dict,
data.edge_index_dict,
target=target,
index=torch.arange(4),
)

metric = unfaithfulness(explainer, explanation, top_k)
assert metric >= 0. and metric <= 1.
48 changes: 48 additions & 0 deletions test/explain/metric/test_fidelity.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,27 @@
import pytest
import torch

from torch_geometric.data import HeteroData
from torch_geometric.explain import (
DummyExplainer,
Explainer,
characterization_score,
fidelity,
fidelity_curve_auc,
)
from torch_geometric.testing import get_random_edge_index


class DummyModel(torch.nn.Module):
def forward(self, x, edge_index):
return x


class HeteroDummyModel(torch.nn.Module):
def forward(self, x_dict, edge_index_dict, **kwargs):
return x_dict['paper']


@pytest.mark.parametrize('explanation_type', ['model', 'phenomenon'])
def test_fidelity(explanation_type):
x = torch.randn(8, 4)
Expand Down Expand Up @@ -47,6 +54,47 @@ def test_fidelity(explanation_type):
assert pos_fidelity == 0.0 and neg_fidelity == 0.0


@pytest.mark.parametrize('explanation_type', ['model', 'phenomenon'])
def test_fidelity_hetero(explanation_type):
data = HeteroData()
data['paper'].x = torch.randn(8, 4)
data['author'].x = torch.randn(10, 4)
data['paper', 'paper'].edge_index = get_random_edge_index(8, 8, 10)
data['paper', 'author'].edge_index = get_random_edge_index(8, 10, 10)
data['author', 'paper'].edge_index = get_random_edge_index(10, 8, 10)

explainer = Explainer(
HeteroDummyModel(),
algorithm=DummyExplainer(),
explanation_type=explanation_type,
node_mask_type='object',
edge_mask_type='object',
model_config=dict(
mode='multiclass_classification',
return_type='raw',
task_level='node',
),
)

target = None
if explanation_type == 'phenomenon':
target = torch.randint(0, data['paper'].x.size(1),
(data['paper'].x.size(0), ))

explanation = explainer(
data.x_dict,
data.edge_index_dict,
target=target,
index=torch.arange(4),
)

pos_fidelity, neg_fidelity = fidelity(explainer, explanation)
# `HeteroDummyModel` returns `x_dict['paper']` and the explainer applies
# object-level masks which uniformly scale each node's feature row, so the
# predicted class is preserved and both fidelity scores collapse to zero.
assert pos_fidelity == 0.0 and neg_fidelity == 0.0


def test_characterization_score():
out = characterization_score(
pos_fidelity=torch.tensor([1.0, 0.6, 0.5, 1.0]),
Expand Down
51 changes: 48 additions & 3 deletions torch_geometric/explain/metric/faithfulness.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from typing import Optional
from typing import Optional, Union

import torch
import torch.nn.functional as F

from torch_geometric.explain import Explainer, Explanation
from torch_geometric.explain import Explainer, Explanation, HeteroExplanation
from torch_geometric.explain.config import MaskType, ModelMode, ModelReturnType
from torch_geometric.explain.metric.fidelity import _hetero_model_kwargs


def unfaithfulness(
explainer: Explainer,
explanation: Explanation,
explanation: Union[Explanation, HeteroExplanation],
top_k: Optional[int] = None,
) -> float:
r"""Evaluates how faithful an :class:`~torch_geometric.explain.Explanation`
Expand Down Expand Up @@ -43,6 +44,9 @@ def unfaithfulness(
raise ValueError("Cannot apply top-k feature selection based on a "
"node mask of type 'object'")

if isinstance(explanation, HeteroExplanation):
return _hetero_unfaithfulness(explainer, explanation, top_k)

node_mask = explanation.get('node_mask')
edge_mask = explanation.get('edge_mask')
x, edge_index = explanation.x, explanation.edge_index
Expand Down Expand Up @@ -71,3 +75,44 @@ def unfaithfulness(

kl_div = F.kl_div(y.log(), y_hat, reduction='batchmean')
return 1 - float(torch.exp(-kl_div))


def _hetero_unfaithfulness(
explainer: Explainer,
explanation: HeteroExplanation,
top_k: Optional[int] = None,
) -> float:
node_mask = explanation.collect('node_mask', allow_empty=True) or None
edge_mask = explanation.collect('edge_mask', allow_empty=True) or None
x = explanation.x_dict
edge_index = explanation.edge_index_dict
kwargs = _hetero_model_kwargs(explanation)

y = explanation.get('prediction')
if y is None: # == ExplanationType.phenomenon
y = explainer.get_prediction(x, edge_index, **kwargs)

if node_mask is not None and top_k is not None:
new_node_mask = {}
for node_type, mask in node_mask.items():
feat_importance = mask.sum(dim=0)
k = min(top_k, feat_importance.size(0))
_, top_k_index = feat_importance.topk(k)
new_mask = torch.zeros_like(mask)
new_mask[:, top_k_index] = 1.0
new_node_mask[node_type] = new_mask
node_mask = new_node_mask

y_hat = explainer.get_masked_prediction(x, edge_index, node_mask,
edge_mask, **kwargs)

if explanation.get('index') is not None:
y, y_hat = y[explanation.index], y_hat[explanation.index]

if explainer.model_config.return_type == ModelReturnType.raw:
y, y_hat = y.softmax(dim=-1), y_hat.softmax(dim=-1)
elif explainer.model_config.return_type == ModelReturnType.log_probs:
y, y_hat = y.exp(), y_hat.exp()

kl_div = F.kl_div(y.log(), y_hat, reduction='batchmean')
return 1 - float(torch.exp(-kl_div))
82 changes: 79 additions & 3 deletions torch_geometric/explain/metric/fidelity.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,28 @@
from typing import Tuple
from typing import Any, Dict, Tuple, Union

import torch
from torch import Tensor

from torch_geometric.explain import Explainer, Explanation
from torch_geometric.explain import Explainer, Explanation, HeteroExplanation
from torch_geometric.explain.config import ExplanationType, ModelMode


def _hetero_model_kwargs(explanation: HeteroExplanation) -> Dict[str, Any]:
# Reconstructs the keyword arguments originally passed to the model.
# Falls back to an empty dict when `_model_args` was not populated
# (e.g. by an older `Explainer`); see PR #10672.
kwargs: Dict[str, Any] = {}
for key in getattr(explanation, '_model_args', []):
if key.endswith('_dict'):
kwargs[key] = explanation.collect(key[:-5], allow_empty=True)
else:
kwargs[key] = explanation[key]
return kwargs


def fidelity(
explainer: Explainer,
explanation: Explanation,
explanation: Union[Explanation, HeteroExplanation],
) -> Tuple[float, float]:
r"""Evaluates the fidelity of an
:class:`~torch_geometric.explain.Explainer` given an
Expand Down Expand Up @@ -50,6 +63,9 @@ def fidelity(
if explainer.model_config.mode == ModelMode.regression:
raise ValueError("Fidelity not defined for 'regression' models")

if isinstance(explanation, HeteroExplanation):
return _hetero_fidelity(explainer, explanation)

node_mask = explanation.get('node_mask')
edge_mask = explanation.get('edge_mask')
kwargs = {key: explanation[key] for key in explanation._model_args}
Expand Down Expand Up @@ -100,6 +116,66 @@ def fidelity(
return float(pos_fidelity), float(neg_fidelity)


def _hetero_fidelity(
explainer: Explainer,
explanation: HeteroExplanation,
) -> Tuple[float, float]:
node_mask = explanation.collect('node_mask', allow_empty=True) or None
edge_mask = explanation.collect('edge_mask', allow_empty=True) or None
kwargs = _hetero_model_kwargs(explanation)

y = explanation.target
if explainer.explanation_type == ExplanationType.phenomenon:
y_hat = explainer.get_prediction(
explanation.x_dict,
explanation.edge_index_dict,
**kwargs,
)
y_hat = explainer.get_target(y_hat)

explain_y_hat = explainer.get_masked_prediction(
explanation.x_dict,
explanation.edge_index_dict,
node_mask,
edge_mask,
**kwargs,
)
explain_y_hat = explainer.get_target(explain_y_hat)

complement_y_hat = explainer.get_masked_prediction(
explanation.x_dict,
explanation.edge_index_dict,
{
k: 1. - v
for k, v in node_mask.items()
} if node_mask is not None else None,
{
k: 1. - v
for k, v in edge_mask.items()
} if edge_mask is not None else None,
**kwargs,
)
complement_y_hat = explainer.get_target(complement_y_hat)

if explanation.get('index') is not None:
y = y[explanation.index]
if explainer.explanation_type == ExplanationType.phenomenon:
y_hat = y_hat[explanation.index]
explain_y_hat = explain_y_hat[explanation.index]
complement_y_hat = complement_y_hat[explanation.index]

if explainer.explanation_type == ExplanationType.model:
pos_fidelity = 1. - (complement_y_hat == y).float().mean()
neg_fidelity = 1. - (explain_y_hat == y).float().mean()
else:
pos_fidelity = ((y_hat == y).float() -
(complement_y_hat == y).float()).abs().mean()
neg_fidelity = ((y_hat == y).float() -
(explain_y_hat == y).float()).abs().mean()

return float(pos_fidelity), float(neg_fidelity)


def characterization_score(
pos_fidelity: Tensor,
neg_fidelity: Tensor,
Expand Down