Skip to content

Commit df52157

Browse files
authored
Merge pull request #356 from ntumlgroup/fix_metric_k
Fix top_k in metrics.py
2 parents a7ec069 + 72cd852 commit df52157

2 files changed

Lines changed: 8 additions & 10 deletions

File tree

libmultilabel/nn/metrics.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -217,19 +217,17 @@ def get_metrics(metric_threshold, monitor_metrics, num_classes, top_k=None):
217217

218218
if match_top_k:
219219
metric_abbr = match_top_k.group(1) # P, R, PR, or nDCG
220-
top_k = int(match_top_k.group(2))
221-
if top_k >= num_classes:
222-
raise ValueError(
223-
f"Invalid metric: {metric}. top_k ({top_k}) is greater than num_classes({num_classes})."
224-
)
220+
k = int(match_top_k.group(2))
221+
if k >= num_classes:
222+
raise ValueError(f"Invalid metric: {metric}. k ({k}) is greater than num_classes({num_classes}).")
225223
if metric_abbr == "P":
226-
metrics[metric] = Precision(num_classes, average="samples", top_k=top_k)
224+
metrics[metric] = Precision(num_classes, average="samples", top_k=k)
227225
elif metric_abbr == "R":
228-
metrics[metric] = Recall(num_classes, average="samples", top_k=top_k)
226+
metrics[metric] = Recall(num_classes, average="samples", top_k=k)
229227
elif metric_abbr == "RP":
230-
metrics[metric] = RPrecision(top_k=top_k)
228+
metrics[metric] = RPrecision(top_k=k)
231229
elif metric_abbr == "nDCG":
232-
metrics[metric] = NDCG(top_k=top_k)
230+
metrics[metric] = NDCG(top_k=k)
233231
# The implementation in torchmetrics stores the prediction/target of all batches,
234232
# which can lead to CUDA out of memory.
235233
# metrics[metric] = RetrievalNormalizedDCG(k=top_k)

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[metadata]
22
name = libmultilabel
3-
version = 0.5.1
3+
version = 0.5.2
44
author = LibMultiLabel Team
55
license = MIT License
66
license_file = LICENSE

0 commit comments

Comments (0)