|
| 1 | +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license |
| 2 | + |
| 3 | +from typing import Optional |
| 4 | + |
| 5 | +from cog import BaseModel, BasePredictor, Input, Path |
| 6 | +from ultralytics import YOLOWorld |
| 7 | + |
| 8 | + |
| 9 | +class Output(BaseModel): |
| 10 | + """Output model for predictions.""" |
| 11 | + |
| 12 | + image: Optional[Path] = None |
| 13 | + json_str: Optional[str] = None |
| 14 | + |
| 15 | + |
| 16 | +class Predictor(BasePredictor): |
| 17 | + """YOLOv8s WorldV2 model predictor for Replicate deployment.""" |
| 18 | + |
| 19 | + def setup(self) -> None: |
| 20 | + """Load YOLOWorld model into memory.""" |
| 21 | + self.model = YOLOWorld("yolov8s-worldv2.pt") |
| 22 | + |
| 23 | + def re_init_model(self, class_names: str) -> None: |
| 24 | + """Re-Initialize model with class names.""" |
| 25 | + self.model = YOLOWorld("yolov8s-worldv2.pt") |
| 26 | + class_list = class_names.split(", ") |
| 27 | + self.model.set_classes(class_list) |
| 28 | + |
| 29 | + def predict( |
| 30 | + self, |
| 31 | + image: Path = Input(description="Input image"), |
| 32 | + conf: float = Input(description="Confidence threshold", default=0.25, ge=0.0, le=1.0), |
| 33 | + iou: float = Input(description="IoU threshold for NMS", default=0.45, ge=0.0, le=1.0), |
| 34 | + imgsz: int = Input(description="Image size", default=640, choices=[320, 416, 512, 640, 832, 1024, 1280]), |
| 35 | + class_names: str = Input( |
| 36 | + description="Comma-separated list of class names to filter results (e.g., 'person, bus, sign') You can also leave it empty to detect classes automatically.", |
| 37 | + default="person, bus, sign", |
| 38 | + ), |
| 39 | + return_json: bool = Input(description="Return detection results as JSON", default=False), |
| 40 | + ) -> Output: |
| 41 | + """Run inference and return annotated image with optional JSON results.""" |
| 42 | + self.re_init_model(class_names) |
| 43 | + result = self.model(str(image), conf=conf, iou=iou, imgsz=imgsz)[0] |
| 44 | + image_path = "output.png" |
| 45 | + result.save(image_path) |
| 46 | + |
| 47 | + if return_json: |
| 48 | + return Output(image=Path(image_path), json_str=result.to_json()) |
| 49 | + else: |
| 50 | + return Output(image=Path(image_path)) |
0 commit comments