diff --git a/baselines/ViT/generate_visualizations.py b/baselines/ViT/generate_visualizations.py index fdde002..e667db5 100644 --- a/baselines/ViT/generate_visualizations.py +++ b/baselines/ViT/generate_visualizations.py @@ -72,11 +72,11 @@ def compute_saliency_and_save(args): # Res = Res - Res.mean() elif args.method == 'lrp': - Res = lrp.generate_LRP(data, start_layer=1, index=index).reshape(data.shape[0], 1, 14, 14) + Res = orig_lrp.generate_LRP(data, start_layer=1, method="grad", index=index).reshape(data.shape[0], 1, 14, 14) # Res = Res - Res.mean() elif args.method == 'transformer_attribution': - Res = lrp.generate_LRP(data, start_layer=1, method="grad", index=index).reshape(data.shape[0], 1, 14, 14) + Res = lrp.generate_LRP(data, start_layer=1, method="transformer_attribution", index=index).reshape(data.shape[0], 1, 14, 14) # Res = Res - Res.mean() elif args.method == 'full_lrp':