diff --git a/apps/Portal/app.py b/apps/Portal/app.py index 1c47ec0..3bc39ab 100644 --- a/apps/Portal/app.py +++ b/apps/Portal/app.py @@ -15,6 +15,7 @@ import stat import subprocess import sys +import tempfile import time import threading import tomllib @@ -100,6 +101,11 @@ TRAINPILOT_BIN = Path("/opt/pilot/apps/TrainPilot/trainpilot.sh") TRAINPILOT_BUNDLED_TOML = Path("/opt/pilot/apps/TrainPilot/newlora.toml") TRAINPILOT_PERSISTENT_TOML = WORKSPACE_ROOT / "config" / "trainpilot" / "newlora.toml" +_TRAINPILOT_LOCAL_PATH_ROOTS = ( + WORKSPACE_ROOT.resolve(), + Path("/opt").resolve(), + Path(os.environ.get("HOME", "/root")).expanduser().resolve(), +) _tp_proc: Optional[subprocess.Popen] = None _tp_logs: deque[str] = deque(maxlen=4000) _tp_output_dir: Optional[Path] = None @@ -170,15 +176,56 @@ def _update_model_pull_job(job: ModelPullJob, line: str) -> None: pass -def _run_model_pull_job(job: ModelPullJob, cmd: list[str]) -> None: +def _cleanup_temp_paths(paths: list[Path]) -> None: + for path in paths: + try: + path.unlink(missing_ok=True) + except Exception: + pass + + +def _build_model_pull_command(name: str) -> tuple[list[str], dict[str, str], list[Path]]: + manifest_line = models_service.manifest_line_for_name( + name, + MANIFEST, + DEFAULT_MANIFEST, + MODELS_DIR, + CONFIG_DIR, + ) + temp_dir = CONFIG_DIR / "model-pulls" + temp_dir.mkdir(parents=True, exist_ok=True) + with tempfile.NamedTemporaryFile( + "w", + dir=temp_dir, + prefix="pull-", + suffix=".manifest", + encoding="utf-8", + delete=False, + ) as handle: + handle.write(manifest_line + "\n") + manifest_path = Path(handle.name) + env = os.environ.copy() + env["MODELS_MANIFEST"] = str(manifest_path) + env.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "0") + return ["/opt/pilot/get-models.sh", "pull-all"], env, [manifest_path] + + +def _run_model_pull_job( + job: ModelPullJob, + cmd: list[str], + env: Optional[dict[str, str]] = None, + cleanup_paths: Optional[list[Path]] = None, +) -> None: + merged_env = os.environ.copy() + merged_env.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "0") + if env: + merged_env.update(env) try: - env = os.environ.copy() - env.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "0") proc = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - env=env, + env=merged_env, bufsize=0, ) job.pid = proc.pid @@ -215,6 +262,7 @@ def _run_model_pull_job(job: ModelPullJob, cmd: list[str]) -> None: job.error = str(e) job.updated_at = time.time() finally: + _cleanup_temp_paths(cleanup_paths or []) with _model_pull_lock: _model_pull_jobs[job.name] = job @@ -575,10 +623,6 @@ class DatasetEntry(BaseModel): path: str -class TrainPilotModelCheckRequest(BaseModel): - toml_path: str = "" - - def _toml_find_first_str(data, key: str) -> Optional[str]: if isinstance(data, dict): v = data.get(key) @@ -1082,13 +1126,31 @@ def _clean_name(name: str) -> str: return cleaned or "dataset" -def _resolve_under_root(root: Path, candidate: Path) -> Path: +def _path_is_within_root(candidate: Path, root: Path) -> bool: root_resolved = os.path.realpath(str(root)) resolved = os.path.realpath(str(candidate)) root_with_sep = os.path.join(root_resolved, "") - if resolved != root_resolved and not resolved.startswith(root_with_sep): + return resolved == root_resolved or resolved.startswith(root_with_sep) + + +def _resolve_under_root(root: Path, candidate: Path) -> Path: + resolved = Path(os.path.realpath(str(candidate))) + if not _path_is_within_root(resolved, root): raise HTTPException(status_code=400, detail="Invalid path") - return Path(resolved) + return resolved + + +def _resolve_local_path_from_roots(raw_value: str, *, field: str, roots: tuple[Path, ...]) -> Path: + raw = (raw_value or "").strip() + if not raw: + raise HTTPException(status_code=400, detail=f"{field} is required") + expanded = os.path.expandvars(os.path.expanduser(raw)) + if not os.path.isabs(expanded): + raise HTTPException(status_code=400, detail=f"{field} must be an absolute local path") + resolved = Path(os.path.realpath(expanded)) + if not any(_path_is_within_root(resolved, root) for root in roots): + raise HTTPException(status_code=400, detail=f"{field} must stay within approved directories") + return resolved def _dataset_dir(name: str) -> Path: @@ -1351,19 +1413,18 @@ def tagpilot_save_item( @app.post("/api/models/{name}/pull") def pull_model(name: str): - models_service.ensure_manifest(MANIFEST, DEFAULT_MANIFEST, MODELS_DIR, CONFIG_DIR) - entries = models_service.parse_manifest(MANIFEST, DEFAULT_MANIFEST, MODELS_DIR, CONFIG_DIR) - if not any(entry.name == name for entry in entries): + try: + cmd, env, cleanup_paths = _build_model_pull_command(name) + except KeyError: raise HTTPException(status_code=404, detail="Unknown model") - cmd = ["/opt/pilot/get-models.sh", "pull", name] print(f"[models] pull start name={name} cmd={' '.join(cmd)}", file=sys.stderr) - # Use existing CLI for consistency try: result = subprocess.run( cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, + env=env, check=True, ) output = result.stdout or "" @@ -1375,15 +1436,17 @@ def pull_model(name: str): tail = output[-4000:] if len(output) > 4000 else output print(f"[models] pull failed name={name} output_tail={tail!r}", file=sys.stderr) raise HTTPException(status_code=500, detail="Model pull failed") + finally: + _cleanup_temp_paths(cleanup_paths) @app.post("/api/models/{name}/pull/start") def pull_model_start(name: str): """Start a model pull in the background (used by UI for progress updates).""" _cleanup_model_pull_jobs() - models_service.ensure_manifest(MANIFEST, DEFAULT_MANIFEST, MODELS_DIR, CONFIG_DIR) - entries = models_service.parse_manifest(MANIFEST, DEFAULT_MANIFEST, MODELS_DIR, CONFIG_DIR) - if not any(e.name == name for e in entries): + try: + cmd, env, cleanup_paths = _build_model_pull_command(name) + except KeyError: raise HTTPException(status_code=404, detail="Unknown model") with _model_pull_lock: @@ -1393,8 +1456,7 @@ def pull_model_start(name: str): job = ModelPullJob(name=name) _model_pull_jobs[name] = job - cmd = ["/opt/pilot/get-models.sh", "pull", name] - threading.Thread(target=_run_model_pull_job, args=(job, cmd), daemon=True).start() + threading.Thread(target=_run_model_pull_job, args=(job, cmd, env, cleanup_paths), daemon=True).start() return _model_pull_job_to_dict(job) @@ -2805,7 +2867,6 @@ class TrainPilotRequest(BaseModel): dataset_name: str output_name: str profile: str = "regular" - toml_path: str = "" def _tp_reader(proc: subprocess.Popen): @@ -2831,21 +2892,6 @@ def _ensure_trainpilot_toml() -> Path: raise HTTPException(status_code=500, detail=f"Bundled TrainPilot TOML not found at {TRAINPILOT_BUNDLED_TOML}") -def _resolve_trainpilot_toml_path(raw_path: str = "") -> Path: - raw = (raw_path or "").strip() - if not raw: - return _ensure_trainpilot_toml() - candidate = Path(raw) - if candidate == TRAINPILOT_BUNDLED_TOML: - candidate = _ensure_trainpilot_toml() - elif not candidate.is_absolute(): - candidate = WORKSPACE_ROOT / candidate - candidate = _resolve_under_root(WORKSPACE_ROOT, candidate) - if candidate.suffix.lower() != ".toml": - raise HTTPException(status_code=400, detail="TrainPilot config must be a TOML file") - return candidate - - @app.post("/api/trainpilot/start") def trainpilot_start(req: TrainPilotRequest): global _tp_proc @@ -2865,9 +2911,7 @@ def trainpilot_start(req: TrainPilotRequest): profile = req.profile.strip() or "regular" if profile not in ("quick_test", "regular", "high_quality"): raise HTTPException(status_code=400, detail="Invalid profile") - toml_path = _resolve_trainpilot_toml_path(req.toml_path) - if not toml_path.exists(): - raise HTTPException(status_code=400, detail=f"TOML not found: {toml_path}") + toml_path = _ensure_trainpilot_toml() # Add debugging info to logs _tp_logs.append(f"=== Starting TrainPilot at {datetime.now().isoformat()} ===") @@ -2928,15 +2972,13 @@ def trainpilot_stop(): @app.post("/api/trainpilot/model-check") -def trainpilot_model_check(req: TrainPilotModelCheckRequest): +def trainpilot_model_check(): """ Parse the selected TrainPilot TOML and check that checkpoint + VAE files exist. If they are missing and can be mapped to a manifest entry, return the model name so the UI can offer to download with progress. """ - toml_path = _resolve_trainpilot_toml_path(req.toml_path) - if not toml_path.exists(): - raise HTTPException(status_code=404, detail=f"TOML not found: {toml_path}") + toml_path = _ensure_trainpilot_toml() try: raw = toml_path.read_bytes() @@ -2968,7 +3010,22 @@ def check_one(kind: str, key: str, value: Optional[str]) -> dict: "model_name": None, "reason": "Not a local file path", } - p = Path(value) + try: + p = _resolve_local_path_from_roots( + value, + field=key, + roots=_TRAINPILOT_LOCAL_PATH_ROOTS, + ) + except HTTPException as exc: + return { + "kind": kind, + "key": key, + "value": value, + "is_local_path": True, + "exists": False, + "model_name": None, + "reason": str(exc.detail), + } exists = p.exists() model_name = None if not exists: diff --git a/apps/Portal/dpipe_api.py b/apps/Portal/dpipe_api.py index e75123f..53fbe59 100644 --- a/apps/Portal/dpipe_api.py +++ b/apps/Portal/dpipe_api.py @@ -1,15 +1,17 @@ import json import os +import re import signal import subprocess import threading +import uuid from collections import deque -from pathlib import Path +from pathlib import Path, PurePosixPath from typing import List, Optional, Union import toml from fastapi import APIRouter, HTTPException -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field # Paths aligned with the runtime layout WORKSPACE = Path(os.environ.get("WORKSPACE_ROOT", "/workspace")) @@ -19,6 +21,9 @@ CONFIG_DIR = WORKSPACE / "configs" DIFFPIPE_APP_DIR = Path(os.environ.get("DIFFPIPE_APP_DIR", WORKSPACE / "apps" / "diffusion-pipe")) DIFFPIPE_REPO_DIR = Path(os.environ.get("DIFFPIPE_REPO_DIR", "/opt/pilot/repos/diffusion-pipe")) +DPIPE_CONFIG_ROOT = CONFIG_DIR / "dpipe" +DPIPE_OUTPUT_ROOT = OUTPUT_DIR / "dpipe" +DPIPE_RUN_REGISTRY = DPIPE_CONFIG_ROOT / "runs.json" # Deepspeed entrypoint (isolated in the diffusion-pipe venv) DEEPSPEED_BIN = os.environ.get("DEEPSPEED_BIN", "/opt/venvs/diffpipe/bin/deepspeed") @@ -39,7 +44,7 @@ def _ensure_dirs(): - for p in [MODEL_DIR, BASE_DATASET_DIR, OUTPUT_DIR, CONFIG_DIR]: + for p in [MODEL_DIR, BASE_DATASET_DIR, OUTPUT_DIR, CONFIG_DIR, DPIPE_CONFIG_ROOT, DPIPE_OUTPUT_ROOT]: p.mkdir(parents=True, exist_ok=True) @@ -109,6 +114,77 @@ def _resolve_local_path(raw_value: str, *, field: str) -> Path: return Path(resolved) +def _clean_name(value: str, *, default: str) -> str: + cleaned = re.sub(r"[^A-Za-z0-9_-]+", "_", value or "").strip("_-") + return cleaned or default + + +def _normalize_dataset_name(raw_value: str) -> str: + raw = (raw_value or "").strip() + if not raw: + raise HTTPException(status_code=400, detail="dataset_name is required") + leaf = PurePosixPath(raw.replace("\\", "/")).name + normalized = leaf[2:] if leaf.startswith("1_") else leaf + return f"1_{_clean_name(normalized, default='dataset')}" + + +def _resolve_dataset_dir(dataset_name: str) -> Path: + wanted = _normalize_dataset_name(dataset_name) + try: + with os.scandir(BASE_DATASET_DIR) as it: + for entry in it: + try: + if entry.is_dir(follow_symlinks=False) and entry.name == wanted: + return Path(entry.path).resolve() + except Exception: + continue + except FileNotFoundError: + pass + raise HTTPException(status_code=404, detail="dataset not found") + + +def _load_run_registry() -> dict[str, dict[str, str]]: + if not DPIPE_RUN_REGISTRY.exists(): + return {} + try: + raw = json.loads(DPIPE_RUN_REGISTRY.read_text(encoding="utf-8")) + if isinstance(raw, dict): + return { + str(key): value + for key, value in raw.items() + if isinstance(key, str) and isinstance(value, dict) + } + except Exception: + pass + return {} + + +def _save_run_registry(registry: dict[str, dict[str, str]]) -> None: + DPIPE_RUN_REGISTRY.parent.mkdir(parents=True, exist_ok=True) + DPIPE_RUN_REGISTRY.write_text( + json.dumps(registry, indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + + +def _managed_run_dirs(run_name: str, dataset_name: str) -> tuple[str, Path, Path]: + safe_run_name = _clean_name(run_name or dataset_name, default="dpipe_run") + registry = _load_run_registry() + record = registry.get(safe_run_name) + run_id = "" + if record: + run_id = _clean_name(str(record.get("run_id", "")), default="") + if not run_id: + run_id = f"run-{uuid.uuid4().hex[:12]}" + registry[safe_run_name] = { + "dataset_name": dataset_name, + "run_id": run_id, + "run_name": safe_run_name, + } + _save_run_registry(registry) + return safe_run_name, DPIPE_CONFIG_ROOT / run_id, DPIPE_OUTPUT_ROOT / run_id + + def _resolve_deepspeed_bin() -> Path: binary = _resolve_local_path(DEEPSPEED_BIN, field="DEEPSPEED_BIN") if not binary.exists() or not binary.is_file(): @@ -246,9 +322,8 @@ def create_training_config( class TrainRequest(BaseModel): - dataset_path: str - config_dir: str - output_dir: str + dataset_name: str + run_name: str = "" epochs: int = 1000 batch_size: int = 1 lr: float = Field(2e-5, alias="learning_rate") @@ -299,12 +374,22 @@ class TrainValidateRequest(BaseModel): llm_path: str clip_path: str +def _parse_json_list(value: Union[str, list], *, field: str) -> list: + if isinstance(value, list): + return value + try: + parsed = json.loads(value) + except Exception as exc: + raise HTTPException(status_code=400, detail=f"Invalid JSON for {field}: {exc}") + if not isinstance(parsed, list): + raise HTTPException(status_code=400, detail=f"{field} must be a JSON list") + return parsed -def _normalize_path(value: str) -> Path: - raw = (value or "").strip() - if not raw: - return Path("") - return Path(os.path.expandvars(os.path.expanduser(raw))) + +def _parse_optional_json_list(value: Optional[str], *, field: str) -> Optional[list]: + if value is None or str(value).strip() == "": + return None + return _parse_json_list(value, field=field) @router.post("/train/validate") @@ -322,50 +407,6 @@ def validate_training_paths(req: TrainValidateRequest): missing.append({"field": key, "path": value}) return {"ok": len(missing) == 0, "missing": missing} - @validator("betas") - def _parse_betas(cls, v): - if isinstance(v, list): - return v - try: - parsed = json.loads(v) - if not isinstance(parsed, list): - raise ValueError - return parsed - except Exception: - raise ValueError("betas must be a JSON list, e.g. [0.9, 0.99]") - - @validator("resolutions_input") - def _parse_resolutions(cls, v): - try: - parsed = json.loads(v) - if not isinstance(parsed, list): - raise ValueError - return v - except Exception: - raise ValueError("resolutions_input must be JSON, e.g. [512] or [[512,512]]") - - @validator("frame_buckets") - def _parse_frames(cls, v): - try: - parsed = json.loads(v) - if not isinstance(parsed, list): - raise ValueError - return v - except Exception: - raise ValueError("frame_buckets must be JSON list, e.g. [1,33]") - - @validator("ar_buckets") - def _parse_ar(cls, v): - if not v: - return "" - try: - parsed = json.loads(v) - if not isinstance(parsed, list): - raise ValueError - return v - except Exception: - raise ValueError("ar_buckets must be JSON list or empty string") - def _ensure_single_run(): with _proc_lock: @@ -378,11 +419,8 @@ def start_training(req: TrainRequest): _ensure_dirs() _ensure_single_run() - ds_path = _resolve_under_root(req.dataset_path, root=BASE_DATASET_DIR, field="dataset_path") - cfg_dir = _resolve_under_root(req.config_dir, root=CONFIG_DIR, field="config_dir") - out_dir = _resolve_under_root(req.output_dir, root=OUTPUT_DIR, field="output_dir") - if not ds_path.exists(): - raise HTTPException(status_code=400, detail="dataset_path does not exist") + ds_path = _resolve_dataset_dir(req.dataset_name) + run_name, cfg_dir, out_dir = _managed_run_dirs(req.run_name, ds_path.name) deepspeed_bin = _resolve_deepspeed_bin() run_dir = _resolve_diffpipe_dir() if not run_dir.exists(): @@ -390,11 +428,12 @@ def start_training(req: TrainRequest): if not (run_dir / "train.py").exists(): raise HTTPException(status_code=500, detail=f"train.py not found in: {run_dir}") try: - resolutions = json.loads(req.resolutions_input) - frames = json.loads(req.frame_buckets) - arb = json.loads(req.ar_buckets) if req.ar_buckets else None - except Exception as e: - raise HTTPException(status_code=400, detail=f"Invalid JSON in resolutions/frame/ar buckets: {e}") + resolutions = _parse_json_list(req.resolutions_input, field="resolutions_input") + frames = _parse_json_list(req.frame_buckets, field="frame_buckets") + arb = _parse_optional_json_list(req.ar_buckets, field="ar_buckets") + betas = _parse_json_list(req.betas, field="betas") + except HTTPException: + raise dataset_cfg = create_dataset_config( ds_path, @@ -438,7 +477,7 @@ def start_training(req: TrainRequest): only_double_blocks=req.only_double_blocks, optimizer_type=req.optimizer_type, lr=req.lr, - betas=req._parse_betas(req.betas), + betas=betas, weight_decay=req.weight_decay, eps=req.eps, enable_wandb=req.enable_wandb, @@ -476,7 +515,14 @@ def start_training(req: TrainRequest): _procs[pid] = proc _deque_for(pid).clear() threading.Thread(target=_read_stream, args=(proc, pid), daemon=True).start() - return {"status": "started", "pid": pid, "config": str(training_cfg)} + return { + "status": "started", + "pid": pid, + "config": str(training_cfg), + "dataset_name": ds_path.name, + "output_dir": str(out_dir), + "run_name": run_name, + } @router.post("/train/stop") diff --git a/apps/Portal/services/models.py b/apps/Portal/services/models.py index 1b4b530..516cb82 100644 --- a/apps/Portal/services/models.py +++ b/apps/Portal/services/models.py @@ -209,6 +209,27 @@ def parse_size(raw: str) -> Optional[int]: return entries +def manifest_line_for_name( + name: str, + manifest_path: Path, + default_manifest_path: Path, + models_dir: Path, + config_dir: Path, +) -> str: + ensure_manifest(manifest_path, default_manifest_path, models_dir, config_dir) + if not manifest_path.exists(): + raise KeyError("Unknown model") + with manifest_path.open() as f: + for raw in f: + line = raw.strip() + if not line or line.startswith("#"): + continue + parts = line.split("|") + if parts and parts[0] == name: + return line + raise KeyError("Unknown model") + + def delete_model(name: str, manifest_path: Path, models_dir: Path) -> int: line = None with manifest_path.open() as f: diff --git a/apps/Portal/static/js/dpipe.js b/apps/Portal/static/js/dpipe.js index 47db872..5f8f12e 100644 --- a/apps/Portal/static/js/dpipe.js +++ b/apps/Portal/static/js/dpipe.js @@ -5,8 +5,7 @@ const DP_SENSITIVE_FIELDS = new Set([ ]); const DP_FIELDS = [ "dp-dataset", - "dp-config", - "dp-output", + "dp-run", "dp-transformer", "dp-vae", "dp-llm", @@ -47,6 +46,7 @@ window.initDpipe = function () { if (status) status.textContent = ""; loadDpipeSettings(); bindDpipeSettings(); + loadDpipeDatasets(); const wandb = document.getElementById("dp-enable-wandb"); const fields = ["dp-wandb-name","dp-wandb-proj","dp-wandb-key"].map(id => document.getElementById(id)); if (wandb && !wandb.dataset.bound) { @@ -89,9 +89,8 @@ window.startDpipe = async function () { return; } const payload = { - dataset_path: val("dp-dataset"), - config_dir: val("dp-config"), - output_dir: val("dp-output"), + dataset_name: val("dp-dataset"), + run_name: val("dp-run"), transformer_path: modelPaths.transformer_path, vae_path: modelPaths.vae_path, llm_path: modelPaths.llm_path, @@ -154,19 +153,21 @@ function bindDpipeSettings() { const el = document.getElementById(id); if (!el || el.dataset.bound) return; el.dataset.bound = "1"; - const handler = () => saveDpipeSettings(); + const handler = () => { + if (id === "dp-run" && el.type !== "checkbox") { + const normalized = normalizeDpipeRunName(el.value || ""); + if (normalized !== el.value) el.value = normalized; + el.dataset.autoFilled = "0"; + } + saveDpipeSettings(); + }; el.addEventListener("change", handler); el.addEventListener("input", handler); }); } function loadDpipeSettings() { - let data = {}; - try { - data = JSON.parse(localStorage.getItem(DP_STORAGE_KEY) || "{}"); - } catch (e) { - data = {}; - } + const data = loadDpipeSettingsData(); DP_FIELDS.forEach(id => { const el = document.getElementById(id); if (!el || DP_SENSITIVE_FIELDS.has(id) || !(id in data)) return; @@ -179,6 +180,14 @@ function loadDpipeSettings() { }); } +function loadDpipeSettingsData() { + try { + return JSON.parse(localStorage.getItem(DP_STORAGE_KEY) || "{}"); + } catch (e) { + return {}; + } +} + function saveDpipeSettings() { const data = {}; DP_FIELDS.forEach(id => { @@ -198,6 +207,45 @@ function saveDpipeSettings() { } } +async function loadDpipeDatasets() { + const sel = document.getElementById("dp-dataset"); + const status = document.getElementById("dp-status"); + if (!sel) return; + const saved = loadDpipeSettingsData(); + const savedValue = typeof saved["dp-dataset"] === "string" ? saved["dp-dataset"] : ""; + if (!sel.dataset.boundDataset) { + sel.dataset.boundDataset = "1"; + sel.addEventListener("change", () => { + syncDpipeRunName(); + saveDpipeSettings(); + }); + } + sel.innerHTML = ``; + try { + const data = await fetchJson("/api/datasets"); + sel.innerHTML = ""; + if (!Array.isArray(data) || !data.length) { + sel.innerHTML = ``; + return; + } + data.forEach(d => { + const opt = document.createElement("option"); + opt.value = String(d.name || ""); + opt.textContent = `${d.display || d.name} (${d.images || 0} images)`; + sel.appendChild(opt); + }); + if (savedValue && Array.from(sel.options).some(opt => opt.value === savedValue)) { + sel.value = savedValue; + } else if (!sel.value && sel.options.length) { + sel.selectedIndex = 0; + } + syncDpipeRunName(); + } catch (e) { + sel.innerHTML = ``; + if (status) status.textContent = `Error loading datasets: ${e.message || e}`; + } +} + function startLogPoll() { if (dpLogTimer) return; const poll = async () => { @@ -232,6 +280,24 @@ function normalizeDpipeLines(data) { return []; } +function syncDpipeRunName(force = false) { + const run = document.getElementById("dp-run"); + const dataset = document.getElementById("dp-dataset"); + if (!run || !dataset) return; + const isAutoFilled = run.dataset.autoFilled === "1"; + if (run.value.trim() && !force && !isAutoFilled) return; + const label = dataset.options[dataset.selectedIndex]?.textContent || dataset.value || "dpipe_run"; + const base = label.replace(/\s*\([^)]*\)\s*$/, "").trim(); + run.value = normalizeDpipeRunName(base); + run.dataset.autoFilled = "1"; +} + +function normalizeDpipeRunName(value) { + return String(value || "") + .trim() + .replace(/\s+/g, "_"); +} + function val(id) { return document.getElementById(id)?.value.trim() || ""; } function num(id) { return parseInt(val(id) || "0", 10); } function checked(id) { return !!document.getElementById(id)?.checked; } diff --git a/apps/Portal/static/js/trainpilot.js b/apps/Portal/static/js/trainpilot.js index 8c11b0e..895a1a5 100644 --- a/apps/Portal/static/js/trainpilot.js +++ b/apps/Portal/static/js/trainpilot.js @@ -9,11 +9,7 @@ window.initTrainpilot = function () { function bindTpControls() { const status = document.getElementById("tp-status"); const output = document.getElementById("tp-output"); - const tomlPath = document.getElementById("tp-toml"); if (status) status.textContent = ""; - if (tomlPath && !tomlPath.value) { - tomlPath.value = "/workspace/config/trainpilot/newlora.toml"; - } if (output && !output.dataset.bound) { output.dataset.bound = "1"; output.addEventListener("input", () => { @@ -98,11 +94,9 @@ function clearModelDownloadUI() { text.textContent = ""; } -async function ensureTrainpilotModelsPresent(tomlPath) { +async function ensureTrainpilotModelsPresent() { const check = await fetchJson("/api/trainpilot/model-check", { method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ toml_path: tomlPath || "" }), }); const missing = (check && check.missing) ? check.missing : []; if (!missing.length) return true; @@ -173,7 +167,7 @@ async function loadTpDatasets() { } data.forEach(d => { const label = `${d.display || d.name} (${d.images || 0} images)`; - const val = d.path || d.name; + const val = d.name || ""; const opt = document.createElement("option"); opt.value = val; opt.textContent = label; @@ -213,14 +207,13 @@ window.startTrainPilot = async function () { const outputEl = document.getElementById("tp-output"); const output = normalizeOutputName(outputEl?.value.trim() || ""); const profile = document.getElementById("tp-profile")?.value || "regular"; - const toml = document.getElementById("tp-toml")?.value.trim() || ""; const status = document.getElementById("tp-status"); if (outputEl) outputEl.value = output; updateEpochExample(output); if (status) status.textContent = "Starting..."; try { clearModelDownloadUI(); - const ok = await ensureTrainpilotModelsPresent(toml); + const ok = await ensureTrainpilotModelsPresent(); if (!ok) { if (status) status.textContent = "Canceled."; return; @@ -232,7 +225,6 @@ window.startTrainPilot = async function () { dataset_name: dataset, output_name: output, profile, - toml_path: toml, }), }); if (status) status.textContent = "Running..."; diff --git a/apps/Portal/static/views/dpipe.html b/apps/Portal/static/views/dpipe.html index 76a35f7..6d5b548 100644 --- a/apps/Portal/static/views/dpipe.html +++ b/apps/Portal/static/views/dpipe.html @@ -27,10 +27,17 @@
/workspace/configs/dpipe and outputs under /workspace/outputs/dpipe.
+ /workspace/config/trainpilot/newlora.toml