Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions packages/prime/src/prime_cli/commands/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,16 @@ def create_run(
help="Skip action status check and run even if environment action failed.",
),
yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation prompt"),
cluster: Optional[str] = typer.Option(
None,
"--cluster",
help=(
"Pin the dispatch to a specific cluster by name. Overrides the "
"TOML's `cluster_name` field. The backend hard-fails with a 400 "
"error if the cluster is unknown, not allocated, or out of "
"capacity — no silent fallback to another cluster."
),
),
) -> None:
"""Launch a Hosted Training run from a config file.

Expand All @@ -891,6 +901,14 @@ def create_run(
console.print(f"[dim]Loading config from {config_path}[/dim]\n")
cfg = load_config(config_path)

# `--cluster` overrides the TOML's `cluster_name` so a user can
# retarget a single dispatch without editing the config. We don't
# validate the name client-side: the backend's picker is the source
# of truth for which clusters the caller can hit, and it returns a
# clear 400 with the available alternatives when the name is wrong.
if cluster is not None:
cfg.cluster_name = cluster

# Collect secrets from all sources
def warn(msg: str) -> None:
console.print(f"[yellow]Warning:[/yellow] {msg}")
Expand Down
128 changes: 128 additions & 0 deletions packages/prime/tests/test_train_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,134 @@ def test_train_init_defaults_to_rl_toml() -> None:
assert Path("rl.toml").exists()


def test_train_help_lists_cluster_flag() -> None:
    # Discoverability guard: `--cluster` must show up in the `train --help`
    # text so nobody has to read the source to learn the option exists, and
    # so a later reshuffle of the command's arguments can't quietly drop it.
    help_args = ["train", "--help"]
    outcome = runner.invoke(app, help_args, env=TEST_ENV)
    help_text = outcome.output

    assert outcome.exit_code == 0, help_text
    assert "--cluster" in help_text


def test_train_cluster_flag_overrides_config_cluster_name(monkeypatch, tmp_path: Path) -> None:
    # When both are present, the `--cluster` flag must take precedence over
    # the TOML's `cluster_name`, letting a user redirect one dispatch without
    # touching the config file. The name itself is not validated on the
    # client (the backend's picker owns that decision), so the only thing
    # checked here is that the flag value arrives in the RLClient.create_run
    # keyword arguments as `cluster_name`.
    captured: dict[str, Any] = {}

    class _StubRun:
        # Minimal stand-in for the object returned by RLClient.create_run.
        id = "run-1"
        status = "QUEUED"
        runs_ahead = None

        def model_dump(self_inner) -> dict[str, Any]:
            return {"id": "run-1", "status": "QUEUED"}

    def fake_create_run(self: Any, **kwargs: Any) -> Any:
        # Record exactly what the CLI tried to send, then report a queued run.
        captured["kwargs"] = kwargs
        return _StubRun()

    def fake_list_models(self: Any, **kwargs: Any) -> list:
        return []

    monkeypatch.setattr("prime_cli.api.rl.RLClient.create_run", fake_create_run)
    monkeypatch.setattr("prime_cli.api.rl.RLClient.list_models", fake_list_models)

    # The config deliberately names a different cluster than the flag does.
    toml_body = (
        'model = "Qwen/Qwen3-0.6B"\n'
        'cluster_name = "config-cluster"\n'
        "\n"
        "[[env]]\n"
        'id = "reverse-text"\n'
    )
    config_path = tmp_path / "rl.toml"
    config_path.write_text(toml_body)

    cli_args = [
        "train",
        str(config_path),
        "--cluster",
        "flag-cluster",
        "--output",
        "json",
        "--yes",
        "--skip-action-check",
    ]
    cli_env = {**TEST_ENV, "PRIME_API_KEY": "test-key"}
    outcome = runner.invoke(app, cli_args, env=cli_env)

    assert outcome.exit_code == 0, outcome.output
    sent_kwargs = captured["kwargs"]
    assert sent_kwargs["cluster_name"] == "flag-cluster"


def test_train_without_cluster_flag_uses_config_cluster_name(monkeypatch, tmp_path: Path) -> None:
    # Counterpart to the override test: when `--cluster` is omitted, the
    # `cluster_name` from the TOML must be the value handed to the backend.
    # This catches the override branch accidentally replacing the plain
    # config path.
    captured: dict[str, Any] = {}

    class _StubRun:
        # Minimal stand-in for the object returned by RLClient.create_run.
        id = "run-1"
        status = "QUEUED"
        runs_ahead = None

        def model_dump(self_inner) -> dict[str, Any]:
            return {"id": "run-1", "status": "QUEUED"}

    def fake_create_run(self: Any, **kwargs: Any) -> Any:
        # Record exactly what the CLI tried to send, then report a queued run.
        captured["kwargs"] = kwargs
        return _StubRun()

    def fake_list_models(self: Any, **kwargs: Any) -> list:
        return []

    monkeypatch.setattr("prime_cli.api.rl.RLClient.create_run", fake_create_run)
    monkeypatch.setattr("prime_cli.api.rl.RLClient.list_models", fake_list_models)

    toml_body = (
        'model = "Qwen/Qwen3-0.6B"\n'
        'cluster_name = "config-cluster"\n'
        "\n"
        "[[env]]\n"
        'id = "reverse-text"\n'
    )
    config_path = tmp_path / "rl.toml"
    config_path.write_text(toml_body)

    # No `--cluster` in the argument list: the config value must pass through.
    cli_args = [
        "train",
        str(config_path),
        "--output",
        "json",
        "--yes",
        "--skip-action-check",
    ]
    cli_env = {**TEST_ENV, "PRIME_API_KEY": "test-key"}
    outcome = runner.invoke(app, cli_args, env=cli_env)

    assert outcome.exit_code == 0, outcome.output
    sent_kwargs = captured["kwargs"]
    assert sent_kwargs["cluster_name"] == "config-cluster"


def test_train_request_submits_model_request(monkeypatch) -> None:
captured: dict[str, Any] = {}

Expand Down
Loading