@@ -217,19 +217,17 @@ def get_metrics(metric_threshold, monitor_metrics, num_classes, top_k=None):
217217
218218 if match_top_k :
219219 metric_abbr = match_top_k .group (1 ) # P, R, PR, or nDCG
220- top_k = int (match_top_k .group (2 ))
221- if top_k >= num_classes :
222- raise ValueError (
223- f"Invalid metric: { metric } . top_k ({ top_k } ) is greater than num_classes({ num_classes } )."
224- )
220+ k = int (match_top_k .group (2 ))
221+ if k >= num_classes :
222+ raise ValueError (f"Invalid metric: { metric } . k ({ k } ) is greater than num_classes({ num_classes } )." )
225223 if metric_abbr == "P" :
226- metrics [metric ] = Precision (num_classes , average = "samples" , top_k = top_k )
224+ metrics [metric ] = Precision (num_classes , average = "samples" , top_k = k )
227225 elif metric_abbr == "R" :
228- metrics [metric ] = Recall (num_classes , average = "samples" , top_k = top_k )
226+ metrics [metric ] = Recall (num_classes , average = "samples" , top_k = k )
229227 elif metric_abbr == "RP" :
230- metrics [metric ] = RPrecision (top_k = top_k )
228+ metrics [metric ] = RPrecision (top_k = k )
231229 elif metric_abbr == "nDCG" :
232- metrics [metric ] = NDCG (top_k = top_k )
230+ metrics [metric ] = NDCG (top_k = k )
233231 # The implementation in torchmetrics stores the prediction/target of all batches,
234232 # which can lead to CUDA out of memory.
235233 # metrics[metric] = RetrievalNormalizedDCG(k=top_k)
0 commit comments