Skip to content

Commit feea570

Browse files
committed
prime train logs: expose -c trainer / -c inference / -c env-server
Backend's /api/v1/rft/runs/{run_id}/logs now accepts `component` and `env_name` params (dedicated full-FT). Surface them through the CLI:

    prime train logs <run_id> -c trainer
    prime train logs <run_id> -c inference
    prime train logs <run_id> -c env-server --env <name>

Legacy `--env <name>/<idx>` still routes through the env-server-logs endpoint (shared-RFT pods, cluster_id-backed lookup). A dedicated env-server (slug, no slash) goes through the unified /logs route.

Per-rank `--pod-index` is intentionally not exposed yet: the chart's `torchrun --local-ranks-filter=0` already collapses in-pod rank fan-out to rank 0 stdout, and Loki's pod-label indexing in this tenant doesn't actually filter the prime-job-* streams, so per-pod inspection on multi-node runs is kubectl + the PVC log files for now.
1 parent 4d900e3 commit feea570

4 files changed

Lines changed: 187 additions & 53 deletions

File tree

packages/prime/src/prime_cli/api/rl.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,18 +74,18 @@ class RLRun(BaseModel):
7474
kind: Optional[str] = Field(None, description="Run kind discriminator")
7575

7676
# Training configuration
77-
rollouts_per_example: int = Field(..., alias="rolloutsPerExample")
78-
seq_len: int = Field(..., alias="seqLen")
79-
max_steps: int = Field(..., alias="maxSteps")
77+
rollouts_per_example: Optional[int] = Field(None, alias="rolloutsPerExample")
78+
seq_len: Optional[int] = Field(None, alias="seqLen")
79+
max_steps: Optional[int] = Field(None, alias="maxSteps")
8080
max_tokens: Optional[int] = Field(None, alias="maxTokens")
81-
batch_size: int = Field(..., alias="batchSize")
81+
batch_size: Optional[int] = Field(None, alias="batchSize")
8282
loss: Optional[str] = "rl"
8383
teacher: Optional[Dict[str, Any]] = Field(
8484
None,
8585
validation_alias=AliasChoices("teacher", "teacherConfig"),
8686
serialization_alias="teacher",
8787
)
88-
base_model: str = Field(..., alias="baseModel")
88+
base_model: Optional[str] = Field(None, alias="baseModel")
8989
environments: List[Dict[str, Any]] = Field(default_factory=list)
9090
run_config: Optional[Dict[str, Any]] = Field(None, alias="runConfig")
9191
eval_config: Optional[Dict[str, Any]] = Field(None, alias="evalConfig")
@@ -423,16 +423,25 @@ def get_logs(
423423
regex: bool = False,
424424
level: Optional[str] = None,
425425
since_seconds: Optional[int] = None,
426+
component: Optional[str] = None,
427+
pod_index: int = 0,
428+
env_name: Optional[str] = None,
426429
) -> str:
427-
"""Get orchestrator logs for a Hosted Training run.
430+
"""Get logs for one component of a Hosted Training run.
431+
432+
Defaults to the orchestrator pod. Dedicated full-FT runs additionally
433+
expose `trainer`, `inference`, and `env-server` components.
434+
`pod_index` narrows to a specific replica for multi-node
435+
trainer/inference; `env_name` picks among per-env env-server
436+
StatefulSets when `component='env-server'`.
428437
429438
Optional filters narrow the result via the platform's log search
430439
backend:
431440
- search: substring (or regex if regex=True) line filter
432441
- level: one of ERROR/WARNING/SUCCESS/INFO/DEBUG
433-
- since_seconds: how far back to look (6086400)
442+
- since_seconds: how far back to look (60-86400)
434443
"""
435-
params: Dict[str, object] = {"tail_lines": tail_lines}
444+
params: Dict[str, Any] = {"tail_lines": tail_lines}
436445
if search:
437446
params["search"] = search
438447
if regex:
@@ -441,6 +450,12 @@ def get_logs(
441450
params["level"] = level
442451
if since_seconds is not None:
443452
params["since_seconds"] = since_seconds
453+
if component:
454+
params["component"] = component
455+
if pod_index:
456+
params["pod_index"] = pod_index
457+
if env_name:
458+
params["env_name"] = env_name
444459
try:
445460
response = self.client.get(f"/rft/runs/{run_id}/logs", params=params)
446461
return response.get("logs", "")

packages/prime/src/prime_cli/commands/rl.py

Lines changed: 74 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -916,10 +916,10 @@ def _format_run_for_display(run: RLRun) -> Dict[str, Any]:
916916
return {
917917
"id": run.id,
918918
"status": run.status,
919-
"model": run.base_model,
919+
"model": run.base_model or "-",
920920
"environments": envs_display,
921-
"steps": f"{run.max_steps}",
922-
"rollouts": str(run.rollouts_per_example),
921+
"steps": "-" if run.max_steps is None else f"{run.max_steps}",
922+
"rollouts": "-" if run.rollouts_per_example is None else str(run.rollouts_per_example),
923923
"created_at": created_at,
924924
"team_id": run.team_id,
925925
}
@@ -1776,11 +1776,12 @@ def get_run(
17761776
if run.status == "QUEUED" and run.runs_ahead is not None:
17771777
status_text += f" (~{run.runs_ahead} runs ahead)"
17781778
console.print(f" Status: [{status_color}]{status_text}[/{status_color}]")
1779-
console.print(f" Model: [magenta]{run.base_model}[/magenta]")
1779+
console.print(f" Model: [magenta]{formatted['model']}[/magenta]")
17801780
console.print(f" Environments: [green]{formatted['environments']}[/green]")
1781-
console.print(f" Max Steps: {run.max_steps}")
1782-
console.print(f" Batch Size: {run.batch_size}")
1783-
console.print(f" Rollouts per Example: {run.rollouts_per_example}")
1781+
console.print(f" Max Steps: {formatted['steps']}")
1782+
batch_size = "-" if run.batch_size is None else str(run.batch_size)
1783+
console.print(f" Batch Size: {batch_size}")
1784+
console.print(f" Rollouts per Example: {formatted['rollouts']}")
17841785
if run.max_tokens:
17851786
console.print(f" Max Tokens: {run.max_tokens}")
17861787
if run.wandb_project:
@@ -1846,12 +1847,9 @@ def delete_run(
18461847
# Try the hosted full-FT delete endpoint first. The backend's kind
18471848
# gate 404s for non-DEDICATED_FULL_FT runs, so a 404 here means
18481849
# "not a hosted run" and we fall back to the LoRA-shared path.
1849-
# This avoids the prior approach of pre-fetching via rl_client.get_run
1850-
# for the discriminator — which fails for DEDICATED_FULL_FT runs
1851-
# whose row doesn't carry the LoRA-required RLRun fields
1852-
# (rollouts_per_example, seq_len, max_steps, batch_size, base_model).
1853-
# Pydantic ValidationError on those would mask the actual run kind
1854-
# and silently route to the wrong endpoint.
1850+
# This avoids relying on list/get discriminator shape before delete:
1851+
# the delete endpoint owns the run-kind decision, and the CLI only
1852+
# falls back when that endpoint says the row is not dedicated full-FT.
18551853
from ..api.training import HostedTrainingClient
18561854

18571855
rl_client = RLClient(api_client)
@@ -2078,6 +2076,14 @@ def _parse_env_qualifier(env: str) -> tuple[str, int]:
20782076
return env, 0
20792077

20802078

2079+
def _parse_env_qualifier_with_index(env: str) -> tuple[str, int, bool]:
2080+
"""Parse an env qualifier and report whether a numeric suffix was present."""
2081+
name, sep, idx_str = env.rpartition("/")
2082+
if sep and name and idx_str.isdigit():
2083+
return name, int(idx_str), True
2084+
return env, 0, False
2085+
2086+
20812087
@app.command("logs", rich_help_panel="Monitoring")
20822088
def get_logs(
20832089
run_id: str = typer.Argument(..., help="Run ID to get logs for"),
@@ -2086,8 +2092,9 @@ def get_logs(
20862092
"--component",
20872093
"-c",
20882094
help=(
2089-
"Pod to read logs from: 'orchestrator' (default) or 'env-server'. "
2090-
"Inferred from --env when omitted."
2095+
"Pod to read logs from: 'orchestrator' (default), 'trainer', "
2096+
"'inference', or 'env-server'. trainer/inference apply only "
2097+
"to dedicated full-FT runs. Inferred from --env when omitted."
20912098
),
20922099
),
20932100
env: Optional[str] = typer.Option(
@@ -2132,33 +2139,41 @@ def get_logs(
21322139
) -> None:
21332140
"""Get logs for a run.
21342141
2135-
Defaults to the orchestrator pod. Pass ``--env <name>`` to read an
2136-
env-server pod instead — useful when an env-server is crash-looping
2137-
(e.g. ``ModuleNotFoundError``) and the orchestrator has stalled at
2138-
"Starting orchestrator step 0".
2142+
Defaults to the orchestrator pod. Use ``--component`` to pick one of
2143+
``trainer`` / ``inference`` / ``env-server`` (dedicated full-FT only).
2144+
Pass ``--env <name>`` to read an env-server pod by name (shorthand for
2145+
``--component=env-server``).
21392146
21402147
List available pods first with ``prime train components <run_id>``.
21412148
2149+
Per-rank narrowing on multi-replica trainer/inference is not yet
2150+
surfaced here — `--local-ranks-filter=0` in the chart's torchrun
2151+
invocation already dedupes the in-pod rank fan-out, and per-pod
2152+
inspection on multi-node runs requires kubectl + the PVC log files.
2153+
21422154
Examples:
21432155
21442156
prime train logs <run_id>
21452157
prime train logs <run_id> -f
21462158
prime train logs <run_id> --search Backpressure
21472159
prime train logs <run_id> --level ERROR --since 1h
21482160
prime train logs <run_id> --search 'Step \\d+' --regex
2161+
prime train logs <run_id> -c trainer
2162+
prime train logs <run_id> -c inference
21492163
prime train logs <run_id> --env reverse-text
21502164
prime train logs <run_id> --env reverse-text/1 -f
21512165
"""
2166+
valid_components = ("orchestrator", "trainer", "inference", "env-server")
21522167
if component is None:
21532168
component = "env-server" if env is not None else "orchestrator"
2154-
elif component not in ("orchestrator", "env-server"):
2169+
elif component not in valid_components:
21552170
raise typer.BadParameter(
2156-
f"Invalid component '{component}'. Use 'orchestrator' or 'env-server'.",
2171+
f"Invalid component '{component}'. Use one of: {', '.join(valid_components)}.",
21572172
param_hint="--component",
21582173
)
2159-
if component == "orchestrator" and env is not None:
2174+
if env is not None and component != "env-server":
21602175
raise typer.BadParameter(
2161-
"--env applies only to env-server logs. Drop --component=orchestrator or drop --env.",
2176+
f"--env applies only to env-server logs. Drop --component={component} or drop --env.",
21622177
param_hint="--env",
21632178
)
21642179
if component == "env-server" and env is None:
@@ -2189,7 +2204,33 @@ def get_logs(
21892204
api_client = APIClient()
21902205
rl_client = RLClient(api_client)
21912206

2192-
if component == "orchestrator":
2207+
env_name_q, env_index_q, env_has_index_q = (
2208+
_parse_env_qualifier_with_index(env) if env is not None else (None, 0, False)
2209+
)
2210+
2211+
if component == "env-server" and env is not None and env_has_index_q:
2212+
assert env_name_q is not None
2213+
# Legacy shared-RFT env-server (`name/index` qualifier) — go
2214+
# through the dedicated env-server endpoint which uses the
2215+
# cluster_id-backed pod lookup path. Dedicated full-FT
2216+
# env-servers use the unified /logs route with
2217+
# component=env-server + env_name (StatefulSets always run
2218+
# one pod per env, so no index disambiguation needed).
2219+
2220+
def fetch(t: int) -> str:
2221+
return rl_client.get_env_server_logs(
2222+
run_id,
2223+
env_name=env_name_q,
2224+
env_index=env_index_q,
2225+
tail_lines=t,
2226+
search=search,
2227+
regex=regex,
2228+
level=normalized_level,
2229+
since_seconds=since_seconds,
2230+
)
2231+
2232+
label = f"env-server {env}"
2233+
elif component == "orchestrator":
21932234

21942235
def fetch(t: int) -> str:
21952236
return rl_client.get_logs(
@@ -2203,22 +2244,25 @@ def fetch(t: int) -> str:
22032244

22042245
label = "orchestrator"
22052246
else:
2206-
assert env is not None # narrowed by validation above
2207-
env_name, env_index = _parse_env_qualifier(env)
2247+
# trainer / inference / dedicated env-server — unified /logs
2248+
# route. env (no slash) names the dedicated env-server's
2249+
# StatefulSet.
2250+
fetch_component = component
2251+
fetch_env = env if component == "env-server" else None
22082252

22092253
def fetch(t: int) -> str:
2210-
return rl_client.get_env_server_logs(
2254+
return rl_client.get_logs(
22112255
run_id,
2212-
env_name=env_name,
2213-
env_index=env_index,
22142256
tail_lines=t,
22152257
search=search,
22162258
regex=regex,
22172259
level=normalized_level,
22182260
since_seconds=since_seconds,
2261+
component=fetch_component,
2262+
env_name=fetch_env,
22192263
)
22202264

2221-
label = f"env-server {env}"
2265+
label = f"env-server {env}" if component == "env-server" else component
22222266

22232267
_stream_logs(
22242268
fetch_fn=fetch,

packages/prime/tests/test_rl_api.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import Any
44

5-
from prime_cli.api.rl import RLClient
5+
from prime_cli.api.rl import RLClient, RLRun
66

77

88
class FakeAPIClient:
@@ -46,6 +46,27 @@ def post(self, endpoint: str, json: dict[str, Any] | None = None) -> dict[str, A
4646
}
4747

4848

49+
def test_run_model_allows_dedicated_full_ft_without_lora_fields() -> None:
    """RLRun validates a DEDICATED_FULL_FT payload that omits the training fields."""
    timestamp = "2026-05-17T00:00:00Z"
    payload = {
        "id": "full-ft-run",
        "name": "dedicated",
        "userId": "user-1",
        "status": "RUNNING",
        "kind": "DEDICATED_FULL_FT",
        "createdAt": timestamp,
        "updatedAt": timestamp,
    }

    run = RLRun.model_validate(payload)

    assert run.kind == "DEDICATED_FULL_FT"
    # None of these keys are in the payload, so each must default to None.
    omitted_fields = (
        "rollouts_per_example",
        "seq_len",
        "max_steps",
        "batch_size",
        "base_model",
    )
    for field_name in omitted_fields:
        assert getattr(run, field_name) is None
68+
69+
4970
def test_get_distributions_preserves_chart_histogram_data() -> None:
5071
api_client = FakeAPIClient()
5172
client = RLClient(api_client) # type: ignore[arg-type]

0 commit comments

Comments
 (0)