diff --git a/demo/predictor.py b/demo/predictor.py index 189ec79..b13eca8 100644 --- a/demo/predictor.py +++ b/demo/predictor.py @@ -15,19 +15,21 @@ class VisualizationDemo(object): - def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False): + def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False, confidence_threshold=0.5): """ Args: cfg (CfgNode): instance_mode (ColorMode): parallel (bool): whether to run the model in different processes from visualization. Useful since the visualization logic can be slow. + confidence_threshold (float): minimum score for instance predictions to be shown """ self.metadata = MetadataCatalog.get( cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused" ) self.cpu_device = torch.device("cpu") self.instance_mode = instance_mode + self.confidence_threshold = confidence_threshold self.parallel = parallel if parallel: @@ -60,9 +62,16 @@ def run_on_image(self, image): vis_output = visualizer.draw_sem_seg( predictions["sem_seg"].argmax(dim=0).to(self.cpu_device) ) - if "instances" in predictions: - instances = predictions["instances"].to(self.cpu_device) - vis_output = visualizer.draw_instance_predictions(predictions=instances) + if "instances" in predictions: + instances = predictions["instances"].to(self.cpu_device) + + # Filter instances by confidence threshold + if self.confidence_threshold > 0: + scores = instances.scores + keep = scores > self.confidence_threshold + instances = instances[keep] + + vis_output = visualizer.draw_instance_predictions(predictions=instances) return predictions, vis_output @@ -94,6 +103,13 @@ def process_predictions(frame, predictions): ) elif "instances" in predictions: predictions = predictions["instances"].to(self.cpu_device) + + # Filter instances by confidence threshold + if self.confidence_threshold > 0: + scores = predictions.scores + keep = scores > self.confidence_threshold + predictions = predictions[keep] + vis_frame = video_visualizer.draw_instance_predictions(frame, predictions) elif "sem_seg" in predictions: vis_frame = video_visualizer.draw_sem_seg(