55#
66# Contact wanghuijie@pjlab.org.cn if you have any issue.
77#
8- # Copyright (c) 2023 The OpenLane-v2 Dataset Authors. All Rights Reserved.
8+ # Copyright (c) 2023 The OpenLane-V2 Dataset Authors. All Rights Reserved.
99#
1010# Licensed under the Apache License, Version 2.0 (the "License");
1111# you may not use this file except in compliance with the License.
@@ -105,7 +105,7 @@ def _tpfp(gts, preds, confidences, distance_matrix, distance_threshold):
105105
106106 tp = np .zeros ((num_preds ), dtype = np .float32 )
107107 fp = np .zeros ((num_preds ), dtype = np .float32 )
108- idx_match_gt = np .ones ((num_preds )) * np .nan
108+ idx_match_gt = np .ones ((num_preds ), dtype = int ) * np .nan
109109
110110 if num_gts == 0 :
111111 fp [...] = 1
@@ -434,29 +434,30 @@ def _mAP_topology_lclc(gts, preds, distance_thresholds):
434434
435435 """
436436 acc = []
437- for r in range (10 ):
438- for distance_threshold in distance_thresholds :
439- for token in gts .keys ():
440- preds_topology_lclc_unmatched = preds [token ]['topology_lclc' ]
437+ for distance_threshold in distance_thresholds :
438+ for token in gts .keys ():
439+ preds_topology_lclc_unmatched = preds [token ]['topology_lclc' ]
441440
442- idx_match_gt = preds [token ][f'lane_centerline_{ distance_threshold } _idx_match_gt' ]
443- confidence = preds [token ][f'lane_centerline_{ distance_threshold } _confidence' ]
444- confidence_thresholds = preds [token ][f'lane_centerline_{ distance_threshold } _confidence_thresholds' ]
445- gt_pred = {m : i for i , (m , c ) in enumerate (zip (idx_match_gt , confidence )) if c >= confidence_thresholds [r ] and not np .isnan (m )}
441+ idx_match_gt = preds [token ][f'lane_centerline_{ distance_threshold } _idx_match_gt' ]
442+ gt_pred = {m : i for i , m in enumerate (idx_match_gt ) if not np .isnan (m )}
446443
447- gts_topology_lclc = gts [token ]['topology_lclc' ]
448- if 0 in gts_topology_lclc .shape :
449- continue
444+ gts_topology_lclc = gts [token ]['topology_lclc' ]
445+ if 0 in gts_topology_lclc .shape :
446+ continue
450447
451- preds_topology_lclc = np .ones_like (gts_topology_lclc , dtype = gts_topology_lclc .dtype ) * np .nan
452- for i in range (preds_topology_lclc .shape [0 ]):
453- for j in range (preds_topology_lclc .shape [1 ]):
454- if i in gt_pred and j in gt_pred :
455- preds_topology_lclc [i ][j ] = preds_topology_lclc_unmatched [gt_pred [i ]][gt_pred [j ]]
456- preds_topology_lclc [np .isnan (preds_topology_lclc )] = 1 - gts_topology_lclc [np .isnan (preds_topology_lclc )]
448+ gt_indices = np .array (list (gt_pred .keys ())).astype (int )
449+ pred_indices = np .array (list (gt_pred .values ())).astype (int )
450+ preds_topology_lclc = np .ones_like (gts_topology_lclc , dtype = gts_topology_lclc .dtype ) * np .nan
451+ xs = gt_indices [:, None ].repeat (len (gt_indices ), 1 )
452+ ys = gt_indices [None , :].repeat (len (gt_indices ), 0 )
453+ preds_topology_lclc [xs , ys ] = preds_topology_lclc_unmatched [pred_indices ][:, pred_indices ]
454+ preds_topology_lclc [np .isnan (preds_topology_lclc )] = (
455+ 1 - gts_topology_lclc [np .isnan (preds_topology_lclc )]) * (0.5 + np .finfo (np .float32 ).eps )
457456
458- acc .append (_AP_directerd (gts = gts_topology_lclc , preds = preds_topology_lclc ))
457+ acc .append (_AP_directerd (gts = gts_topology_lclc , preds = preds_topology_lclc ))
459458
459+ if len (acc ) == 0 :
460+ return np .float32 (0 )
460461 return np .hstack (acc ).mean ()
461462
462463def _mAP_topology_lcte (gts , preds , distance_thresholds ):
@@ -479,41 +480,37 @@ def _mAP_topology_lcte(gts, preds, distance_thresholds):
479480
480481 """
481482 acc = []
482- for r in range (10 ):
483- for distance_threshold_lane_centerline in distance_thresholds ['lane_centerline' ]:
484- for distance_threshold_traffic_element in distance_thresholds ['traffic_element' ]:
485- for token in gts .keys ():
486- preds_topology_lcte_unmatched = preds [token ]['topology_lcte' ]
487-
488- idx_match_gt_lane_centerline = preds [token ][f'lane_centerline_{ distance_threshold_lane_centerline } _idx_match_gt' ]
489- confidence_lane_centerline = preds [token ][f'lane_centerline_{ distance_threshold_lane_centerline } _confidence' ]
490- confidence_thresholds_lane_centerline = preds [token ][f'lane_centerline_{ distance_threshold_lane_centerline } _confidence_thresholds' ]
491- gt_pred_lane_centerline = {
492- m : i for i , (m , c ) in enumerate (zip (idx_match_gt_lane_centerline , confidence_lane_centerline )) \
493- if c >= confidence_thresholds_lane_centerline [r ] and not np .isnan (m )
494- }
495-
496- idx_match_gt_traffic_element = preds [token ][f'traffic_element_{ distance_threshold_traffic_element } _idx_match_gt' ]
497- confidence_traffic_element = preds [token ][f'traffic_element_{ distance_threshold_traffic_element } _confidence' ]
498- confidence_thresholds_traffic_element = preds [token ][f'traffic_element_{ distance_threshold_traffic_element } _confidence_thresholds' ]
499- gt_pred_traffic_element = {
500- m : i for i , (m , c ) in enumerate (zip (idx_match_gt_traffic_element , confidence_traffic_element )) \
501- if c >= confidence_thresholds_traffic_element [r ] and not np .isnan (m )
502- }
503-
504- gts_topology_lcte = gts [token ]['topology_lcte' ]
505- if 0 in gts_topology_lcte .shape :
506- continue
507-
508- preds_topology_lcte = np .ones_like (gts_topology_lcte , dtype = gts_topology_lcte .dtype ) * np .nan
509- for i in range (preds_topology_lcte .shape [0 ]):
510- for j in range (preds_topology_lcte .shape [1 ]):
511- if i in gt_pred_lane_centerline and j in gt_pred_traffic_element :
512- preds_topology_lcte [i ][j ] = preds_topology_lcte_unmatched [gt_pred_lane_centerline [i ]][gt_pred_traffic_element [j ]]
513- preds_topology_lcte [np .isnan (preds_topology_lcte )] = 1 - gts_topology_lcte [np .isnan (preds_topology_lcte )]
514-
515- acc .append (_AP_undirecterd (gts = gts_topology_lcte , preds = preds_topology_lcte ))
516-
483+ for distance_threshold_lane_centerline in distance_thresholds ['lane_centerline' ]:
484+ for distance_threshold_traffic_element in distance_thresholds ['traffic_element' ]:
485+ for token in gts .keys ():
486+ preds_topology_lcte_unmatched = preds [token ]['topology_lcte' ]
487+
488+ idx_match_gt_lane_centerline = preds [token ][f'lane_centerline_{ distance_threshold_lane_centerline } _idx_match_gt' ]
489+ gt_pred_lane_centerline = {m : i for i , m in enumerate (idx_match_gt_lane_centerline ) if not np .isnan (m )}
490+
491+ idx_match_gt_traffic_element = preds [token ][f'traffic_element_{ distance_threshold_traffic_element } _idx_match_gt' ]
492+ gt_pred_traffic_element = {m : i for i , m in enumerate (idx_match_gt_traffic_element ) if not np .isnan (m )}
493+
494+ gts_topology_lcte = gts [token ]['topology_lcte' ]
495+ if 0 in gts_topology_lcte .shape :
496+ continue
497+
498+ gt_indices_lc = np .array (list (gt_pred_lane_centerline .keys ())).astype (int )
499+ pred_indices_lc = np .array (list (gt_pred_lane_centerline .values ())).astype (int )
500+ gt_indices_te = np .array (list (gt_pred_traffic_element .keys ())).astype (int )
501+ pred_indices_te = np .array (list (gt_pred_traffic_element .values ())).astype (int )
502+
503+ preds_topology_lcte = np .ones_like (gts_topology_lcte , dtype = gts_topology_lcte .dtype ) * np .nan
504+ xs = gt_indices_lc [:, None ].repeat (len (gt_indices_te ), 1 )
505+ ys = gt_indices_te [None , :].repeat (len (gt_indices_lc ), 0 )
506+ preds_topology_lcte [xs , ys ] = preds_topology_lcte_unmatched [pred_indices_lc ][:, pred_indices_te ]
507+ preds_topology_lcte [np .isnan (preds_topology_lcte )] = (
508+ 1 - gts_topology_lcte [np .isnan (preds_topology_lcte )]) * (0.5 + np .finfo (np .float32 ).eps )
509+
510+ acc .append (_AP_undirecterd (gts = gts_topology_lcte , preds = preds_topology_lcte ))
511+
512+ if len (acc ) == 0 :
513+ return np .float32 (0 )
517514 return np .hstack (acc ).mean ()
518515
519516def evaluate (ground_truth , predictions , verbose = True ):
0 commit comments