Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 215 additions & 34 deletions comfy_cli/command/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@

import requests
import typer
import yaml
from rich import print
from rich.markup import escape

from comfy_cli import constants, tracking, ui
from comfy_cli.config_manager import ConfigManager
from comfy_cli.constants import DEFAULT_COMFY_MODEL_PATH
from comfy_cli.extra_model_paths import collect_extra_paths, paths_for_category
from comfy_cli.file_utils import DownloadException, check_unauthorized, download_file
from comfy_cli.workspace_manager import WorkspaceManager

Expand All @@ -37,6 +39,21 @@ def get_workspace() -> pathlib.Path:
return pathlib.Path(workspace_manager.workspace_path)


def _resolve_default_relative_path(category: str | None, basemodel: str, extras: list) -> str:
    """Pick the destination subdir for a typed download.

    Returns an absolute path string when ``extras`` configures the category
    (pathlib's ``/`` operator discards the workspace prefix in that case).
    Otherwise returns the workspace-relative ``models/<category>/<basemodel>``
    form preserved from comfy-cli's existing behavior.
    """
    if category and extras:
        matches = paths_for_category(extras, category)
        if matches:
            # First configured path wins; append basemodel only when present.
            target = matches[0]
            return str(target / basemodel) if basemodel else str(target)
    return os.path.join(DEFAULT_COMFY_MODEL_PATH, category or "", basemodel)


def _format_elapsed(seconds: float) -> str:
"""Format elapsed seconds into a human-readable string."""
rounded = round(seconds, 1)
Expand Down Expand Up @@ -243,10 +260,33 @@ def download(
show_default=False,
),
] = None,
extra_model_paths_config: Annotated[
list[pathlib.Path] | None,
typer.Option(
"--extra-model-paths-config",
help="Additional extra_model_paths.yaml file(s) to honor. Repeatable.",
show_default=False,
),
] = None,
extra_model_paths: Annotated[
bool,
typer.Option(
"--extra-model-paths/--no-extra-model-paths",
help="Honor extra_model_paths.yaml from the workspace and any --extra-model-paths-config files.",
show_default=False,
),
] = True,
):
if relative_path is not None:
relative_path = os.path.expanduser(relative_path)

extras: list = []
if extra_model_paths:
try:
extras = collect_extra_paths(get_workspace(), extra_model_paths_config or [])
except yaml.YAMLError as e:
print(f"[yellow]Warning: extra_model_paths YAML is invalid; ignoring extras ({escape(str(e))})[/yellow]")

local_filename = None
headers = None

Expand Down Expand Up @@ -278,7 +318,7 @@ def download(
if model_path is None:
model_path = ui.prompt_input("Enter model type path (e.g. loras, checkpoints, ...)", default="")

relative_path = os.path.join(DEFAULT_COMFY_MODEL_PATH, model_path, basemodel)
relative_path = _resolve_default_relative_path(model_path, basemodel, extras)
elif is_civitai_api_url:
local_filename, url, model_type, basemodel = request_civitai_model_version_api(version_id, headers)

Expand All @@ -288,7 +328,7 @@ def download(
if model_path is None:
model_path = ui.prompt_input("Enter model type path (e.g. loras, checkpoints, ...)", default="")

relative_path = os.path.join(DEFAULT_COMFY_MODEL_PATH, model_path, basemodel)
relative_path = _resolve_default_relative_path(model_path, basemodel, extras)
elif is_huggingface_url:
model_id = "/".join(url.split("/")[-2:])

Expand All @@ -297,7 +337,7 @@ def download(
if relative_path is None:
model_path = ui.prompt_input("Enter model type path (e.g. loras, checkpoints, ...)", default="")
basemodel = ui.prompt_input("Enter base model (e.g. SD1.5, SDXL, ...)", default="")
relative_path = os.path.join(DEFAULT_COMFY_MODEL_PATH, model_path, basemodel)
relative_path = _resolve_default_relative_path(model_path, basemodel, extras)
else:
print("Model source is unknown")

Expand Down Expand Up @@ -388,47 +428,93 @@ def remove(
help="Confirm for deletion and skip the prompt",
show_default=False,
),
extra_model_paths_config: list[pathlib.Path] | None = typer.Option(
None,
"--extra-model-paths-config",
help="Additional extra_model_paths.yaml file(s) to honor. Repeatable.",
show_default=False,
),
extra_model_paths: bool = typer.Option(
True,
"--extra-model-paths/--no-extra-model-paths",
help="Honor extra_model_paths.yaml from the workspace and any --extra-model-paths-config files.",
show_default=False,
),
):
"""Remove one or more downloaded models, either by specifying them directly or through an interactive selection."""
model_dir = get_workspace() / relative_path
available_models = list_models(model_dir)
primary = get_workspace() / relative_path
extras = _load_extras_safely(extra_model_paths, extra_model_paths_config)
roots = _enumerate_search_roots(primary, extras)
scanned = _scan_all_roots(roots)

if not available_models:
if not scanned:
typer.echo("No models found to remove.")
return

model_dir_resolved = model_dir.resolve()
resolved_roots: list[pathlib.Path] = []
for root, _ in roots:
try:
resolved_roots.append(root.resolve())
except OSError:
continue

to_delete = []
# Scenario #1: User provided model names to delete
to_delete: list[pathlib.Path] = []
if model_names:
# Validate and filter models to delete based on provided names
missing_models = []
ambiguous = []
for name in model_names:
model_path = (model_dir / name).resolve()
if not model_path.is_relative_to(model_dir_resolved):
typer.echo(f"Invalid model path: {name}")
valid_matches: set[pathlib.Path] = set()
any_outside = False
for root, _ in roots:
try:
candidate = (root / name).resolve()
except OSError:
continue
if any(candidate.is_relative_to(r) for r in resolved_roots):
if candidate.is_file():
valid_matches.add(candidate)
else:
any_outside = True

if not valid_matches:
if any_outside:
typer.echo(f"Invalid model path: {name}")
else:
missing_models.append(name)
continue
if model_path.is_file():
to_delete.append(model_path)
else:
missing_models.append(name)

if len(valid_matches) > 1:
ambiguous.append((name, sorted(valid_matches)))
continue

to_delete.append(valid_matches.pop())

if ambiguous:
for name, paths in ambiguous:
typer.echo(f"Ambiguous model name '{name}'; matches multiple paths:")
for p in paths:
typer.echo(f" {p}")
typer.echo("Specify a more specific path to disambiguate.")
if not to_delete:
return

if missing_models:
typer.echo("The following models were not found and cannot be removed: " + ", ".join(missing_models))
if not to_delete:
return # Exit if no valid models were found

# Scenario #2: User did not provide model names, prompt for selection
return
else:
rel_names = [str(model.relative_to(model_dir)) for model in available_models]
selections = ui.prompt_multi_select("Select models to delete:", rel_names)
if len(roots) == 1:
single_root = roots[0][0]
labels_to_paths = {str(file.relative_to(single_root)): file for file, _, _ in scanned}
else:
labels_to_paths = {str(file): file for file, _, _ in scanned}

selections = ui.prompt_multi_select("Select models to delete:", list(labels_to_paths.keys()))
if not selections:
typer.echo("No models selected for deletion.")
return
to_delete = [model_dir / selection for selection in selections]
to_delete = [labels_to_paths[sel] for sel in selections]

# Confirm deletion
if to_delete and (
confirm or ui.prompt_confirm_action("Are you sure you want to delete the selected files?", False)
):
Expand All @@ -446,6 +532,81 @@ def list_models(path: pathlib.Path) -> list[pathlib.Path]:
return sorted(f for f in path.rglob("*") if f.is_file())


def _load_extras_safely(use_extras: bool, extra_configs: list[pathlib.Path] | None) -> list:
if not use_extras:
return []
try:
return collect_extra_paths(get_workspace(), extra_configs or [])
except yaml.YAMLError as e:
print(f"[yellow]Warning: extra_model_paths YAML is invalid; ignoring extras ({escape(str(e))})[/yellow]")
return []


def _enumerate_search_roots(primary_root: pathlib.Path, extras: list) -> list[tuple[pathlib.Path, str | None]]:
"""Return ``(root, category)`` pairs to scan, longest-first.

The primary root carries ``category=None`` so list rendering preserves
today's "category from path" behavior. Extras roots carry their canonical
category name for the Type-column prefix. Roots are deduplicated by
realpath; unresolvable roots (e.g., circular symlinks) are skipped with
a warning. Sorting longest-first ensures a file under nested roots is
assigned to the most specific one.
"""
candidates: list[tuple[pathlib.Path, str | None]] = [(primary_root, None)]
for ep in extras:
candidates.append((ep.path, ep.category))

seen_resolved: set[pathlib.Path] = set()
unique: list[tuple[pathlib.Path, str | None]] = []
for root, category in candidates:
try:
resolved = root.resolve()
except OSError as e:
print(f"[yellow]Warning: skipping {root}: {e}[/yellow]")
continue
if resolved in seen_resolved:
continue
seen_resolved.add(resolved)
unique.append((root, category))

unique.sort(key=lambda rc: len(rc[0].parts), reverse=True)
return unique


def _scan_all_roots(
    roots: list[tuple[pathlib.Path, str | None]],
) -> list[tuple[pathlib.Path, pathlib.Path, str | None]]:
    """Walk every root and collect each model file exactly once.

    A file is attributed to the first root that yields it; because the
    input is ordered deepest-first, that is its most specific containing
    root. Files whose path cannot be resolved are skipped. The result is
    ``(file, root, category)`` tuples sorted by file path.
    """
    claimed: set[pathlib.Path] = set()
    found: list[tuple[pathlib.Path, pathlib.Path, str | None]] = []
    for root, category in roots:
        for model_file in list_models(root):
            try:
                real = model_file.resolve()
            except OSError:
                continue
            if real not in claimed:
                claimed.add(real)
                found.append((model_file, root, category))
    return sorted(found, key=lambda entry: entry[0])


def _format_type_column(file: pathlib.Path, root: pathlib.Path, category: str | None) -> str:
"""Compute Type column text. For extras roots the canonical category is
prepended so output is consistent with the workspace listing where the
category is implicit in the on-disk subdir."""
rel = file.relative_to(root)
parent = str(rel.parent) if len(rel.parts) > 1 else ""
if category is None:
return parent
if not parent or parent == ".":
return category
return f"{category}/{parent}"


@app.command("list")
@tracking.track_command("model")
def list_command(
Expand All @@ -455,20 +616,40 @@ def list_command(
help="The relative path from the current workspace where the models are stored.",
show_default=True,
),
extra_model_paths_config: list[pathlib.Path] | None = typer.Option(
None,
"--extra-model-paths-config",
help="Additional extra_model_paths.yaml file(s) to honor. Repeatable.",
show_default=False,
),
extra_model_paths: bool = typer.Option(
True,
"--extra-model-paths/--no-extra-model-paths",
help="Honor extra_model_paths.yaml from the workspace and any --extra-model-paths-config files.",
show_default=False,
),
):
"""Display a list of all models currently downloaded in a table format."""
model_dir = get_workspace() / relative_path
models = list_models(model_dir)
primary = get_workspace() / relative_path
extras = _load_extras_safely(extra_model_paths, extra_model_paths_config)
roots = _enumerate_search_roots(primary, extras)
scanned = _scan_all_roots(roots)

if not models:
if not scanned:
typer.echo("No models found.")
return

# Prepare data for table display
show_source = len({r for _, r, _ in scanned}) > 1
data = []
for model in models:
rel = model.relative_to(model_dir)
model_type = str(rel.parent) if len(rel.parts) > 1 else ""
data.append((model.name, model_type, f"{model.stat().st_size // 1024} KB"))
column_names = ["Model Name", "Type", "Size"]
ui.display_table(data, column_names)
for file, root, category in scanned:
type_str = _format_type_column(file, root, category)
size_str = f"{file.stat().st_size // 1024} KB"
if show_source:
data.append((file.name, type_str, size_str, str(root)))
else:
data.append((file.name, type_str, size_str))

columns = ["Model Name", "Type", "Size"]
if show_source:
columns.append("Source")
ui.display_table(data, columns)
Loading
Loading