@@ -580,32 +580,43 @@ def train_eval(
580580 )
581581
582582 # Applying masking (removing pixels where gt == -1)
583- unmasked_label = label != - 1
584- label = label [unmasked_label ]
583+ if task_type == "segmentation" :
584+ unmasked_label = [l != - 1 for l in label ]
585+ label = [l [u ] for l , u in zip (label , unmasked_label )]
586+ else :
587+ label = label .view (- 1 )
585588 loss = 0
586589 for i in range (len (outputs )):
587590 output = outputs [i ]
588591 out = []
589- for c in range (output .shape [1 ]):
590- out .append (output [:, c ][unmasked_label ].unsqueeze (- 1 ))
591- out = torch .cat (out , dim = - 1 )
592-
593- # Compute loss
594- curr_loss = criterion (out , label )
592+ if task_type == "segmentation" :
593+ for o , m in zip (output , unmasked_label ):
594+ out .append (torch .cat ([o [c ][m ].unsqueeze (- 1 ) for c in range (o .shape [0 ])], dim = - 1 ))
595+ curr_loss = sum ([criterion (o , l ) for o , l in zip (out , label )]) / len (out )
596+ else :
597+ for c in range (output .shape [1 ]):
598+ out .append (output [:, c ].unsqueeze (- 1 ))
599+ out = torch .cat (out , dim = - 1 )
600+ curr_loss = criterion (out , label )
595601 loss += curr_loss
596602
597603 # Logging
598604 if batch_id == 0 :
599605 tot_loss .append ([curr_loss .item ()])
600606 if comp_metrics :
601- all_out .append ([out .detach ().cpu ()])
607+ all_out .append ([[o .detach ().cpu () for o in out ]] if task_type == 'segmentation'
608+ else [out .detach ().cpu ()])
602609 else :
603610 tot_loss [i ].append (curr_loss .item ())
604611 if comp_metrics :
605- all_out [i ].append (out .detach ().cpu ())
612+ all_out [i ].append ([o .detach ().cpu () for o in out ] if task_type == 'segmentation'
613+ else out .detach ().cpu ())
606614 # Logging
607615 if comp_metrics :
608- all_label .append (label .cpu ())
616+ if task_type == "segmentation" :
617+ all_label .extend ([l .cpu () for l in label ])
618+ else :
619+ all_label .append (label .cpu ())
609620
610621 if run_type == "train" :
611622 # Compute gradients
@@ -627,18 +638,27 @@ def train_eval(
627638 viz_im = None
628639
629640 if comp_metrics :
630- # Computing metrics
631- all_label = torch .cat (all_label )
632- metrics = []
633- for i in range (len (all_out )):
634- all_out [i ] = torch .cat (all_out [i ])
635- all_out [i ] = F .softmax (all_out [i ], dim = 1 )
636- classification_metrics = compute_metrics (all_out [i ], None , all_label )
637- conformal_metrics = compute_calibration_metrics (all_out [i ], all_label )
638- curr_metrics = (
639- classification_metrics | conformal_metrics
640- ) # merging dictionaries
641- metrics .append (curr_metrics )
641+ if task_type == "segmentation" :
642+ metrics = []
643+ for i in range (len (all_out )):
644+ all_out [i ] = [F .softmax (item , dim = 1 ) for batch in all_out [i ] for item in batch ]
645+ all_metrics = [compute_metrics (o , None , l , True ) for o , l in zip (all_out [i ], all_label ) if len (l ) > 0 ]
646+ weights = [len (l ) for l in all_label if len (l ) > 0 ]
647+ metrics .append ({key : np .average ([d [key ] for d in all_metrics ], weights = weights ) for key in all_metrics [0 ]} |
648+ {f'{ key } _per_sample' : [d [key ] for d in all_metrics ] for key in all_metrics [0 ]})
649+ else :
650+ # Computing metrics
651+ all_label = torch .cat (all_label )
652+ metrics = []
653+ for i in range (len (all_out )):
654+ all_out [i ] = torch .cat (all_out [i ])
655+ all_out [i ] = F .softmax (all_out [i ], dim = 1 )
656+ classification_metrics = compute_metrics (all_out [i ], None , all_label )
657+ conformal_metrics = compute_calibration_metrics (all_out [i ], all_label )
658+ curr_metrics = (
659+ classification_metrics | conformal_metrics
660+ ) # merging dictionaries
661+ metrics .append (curr_metrics )
642662 else :
643663 metrics = None
644664
0 commit comments