11# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
22
33import json
4- from typing import Any , Dict
5-
6- from cog import BasePredictor , Input , Path
4+ from typing import Optional
75from ultralytics import YOLO
6+ from cog import BasePredictor , Input , Path , BaseModel
7+
8+
9+ class Output (BaseModel ):
10+ image : Optional [Path ] = None
11+ json_str : Optional [str ] = None
812
913
1014class Predictor (BasePredictor ):
@@ -21,13 +25,16 @@ def predict(
2125 iou : float = Input (description = "IoU threshold for NMS" , default = 0.45 , ge = 0.0 , le = 1.0 ),
2226 imgsz : int = Input (description = "Image size" , default = 640 , choices = [320 , 416 , 512 , 640 , 832 , 1024 , 1280 ]),
2327 return_json : bool = Input (description = "Return detection results as JSON" , default = False ),
24- ) -> Dict [ str , Any ] | Path :
28+ ) -> Output :
2529 """Run inference and return annotated image with optional JSON results."""
2630 result = self .model (str (image ), conf = conf , iou = iou , imgsz = imgsz )[0 ]
2731 image_path = "output.png"
2832 result .save (image_path )
2933
3034 if return_json :
31- return {"image" : Path (image_path ), "results" : json .loads (result .to_json ())}
35+ return Output (
36+ image = Path (image_path ),
37+ json_str = result .to_json ()
38+ )
3239 else :
33- return Path (image_path )
40+ return Output ( image = Path (image_path ) )
0 commit comments