Skip to content

Commit 899ea46

Browse files
committed
add custom metrics to AttackLogManager through AttackArgs
1 parent 5ce8e26 commit 899ea46

3 files changed

Lines changed: 17 additions & 4 deletions

File tree

textattack/attack_args.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import os
99
import sys
1010
import time
11+
from typing import Dict, Optional
1112

1213
import textattack
1314
from textattack.shared.utils import ARGS_SPLIT_TOKEN, load_module_from_file
@@ -207,6 +208,7 @@ class AttackArgs:
207208
disable_stdout: bool = False
208209
silent: bool = False
209210
enable_advance_metrics: bool = False
211+
metrics: Optional[Dict] = None
210212

211213
def __post_init__(self):
212214
if self.num_successful_examples:
@@ -386,12 +388,13 @@ def _add_parser_args(cls, parser):
386388

387389
@classmethod
388390
def create_loggers_from_args(cls, args):
391+
"""Creates AttackLogManager from an AttackArgs object."""
389392
assert isinstance(
390393
args, cls
391394
), f"Expect args to be of type `{type(cls)}`, but got type `{type(args)}`."
392395

393396
# Create logger
394-
attack_log_manager = textattack.loggers.AttackLogManager()
397+
attack_log_manager = textattack.loggers.AttackLogManager(args.metrics)
395398

396399
# Get current time for file naming
397400
timestamp = time.strftime("%Y-%m-%d-%H-%M")

textattack/loggers/attack_log_manager.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
========================
44
"""
55

6+
from typing import Dict, Optional
7+
68
from textattack.metrics.attack_metrics import (
79
AttackQueries,
810
AttackSuccessRate,
@@ -22,10 +24,17 @@
2224
class AttackLogManager:
2325
"""Logs the results of an attack to all attached loggers."""
2426

25-
def __init__(self):
27+
# metrics maps strings (metric names) to textattack.metric.Metric objects
28+
metrics: Dict
29+
30+
def __init__(self, metrics: Optional[Dict]):
2631
self.loggers = []
2732
self.results = []
2833
self.enable_advance_metrics = False
34+
if metrics is None:
35+
self.metrics = {}
36+
else:
37+
self.metrics = metrics
2938

3039
def enable_stdout(self):
3140
self.loggers.append(FileLogger(stdout=True))
@@ -127,6 +136,9 @@ def log_summary(self):
127136
["Avg num queries:", attack_query_stats["avg_num_queries"]]
128137
)
129138

139+
for metric_name, metric in self.metrics.items():
140+
summary_table_rows.append([metric_name, metric.calculate(self.results)])
141+
130142
if self.enable_advance_metrics:
131143
perplexity_stats = Perplexity().calculate(self.results)
132144
use_stats = USEMetric().calculate(self.results)

textattack/shared/utils/tensor.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ def batch_model_predict(model_predict, inputs, batch_size=32):
99
"""
1010
outputs = []
1111
i = 0
12-
# print("batch_model_predict", inputs.shape)
13-
# print("inputs:", inputs)
1412
while i < len(inputs):
1513
batch = inputs[i : i + batch_size]
1614
batch_preds = model_predict(batch)

0 commit comments

Comments
 (0)