Skip to content

Commit 8aafec7

Browse files
committed
prime train logs: expose -c trainer / -c inference / -c env-server
Backend's /api/v1/rft/runs/{run_id}/logs now accepts component + env_name params (dedicated full-FT). Surface them through the CLI: prime train logs <run_id> -c trainer prime train logs <run_id> -c inference prime train logs <run_id> -c env-server --env <name> Legacy --env <name>/<idx> still routes through the env-server-logs endpoint (shared-RFT pods, cluster_id-backed lookup). Dedicated env-server (slug, no slash) goes through the unified /logs route. Per-rank --pod-index intentionally not exposed yet: the chart's torchrun --local-ranks-filter=0 already collapses in-pod rank fan-out to rank 0 stdout, and Loki's pod-label indexing in this tenant doesn't actually filter the prime-job-* streams — per-pod inspection on multi-node runs is kubectl + the PVC log files for now.
1 parent 4d900e3 commit 8aafec7

4 files changed

Lines changed: 184 additions & 47 deletions

File tree

packages/prime/src/prime_cli/api/rl.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,18 +74,18 @@ class RLRun(BaseModel):
7474
kind: Optional[str] = Field(None, description="Run kind discriminator")
7575

7676
# Training configuration
77-
rollouts_per_example: int = Field(..., alias="rolloutsPerExample")
78-
seq_len: int = Field(..., alias="seqLen")
79-
max_steps: int = Field(..., alias="maxSteps")
77+
rollouts_per_example: Optional[int] = Field(None, alias="rolloutsPerExample")
78+
seq_len: Optional[int] = Field(None, alias="seqLen")
79+
max_steps: Optional[int] = Field(None, alias="maxSteps")
8080
max_tokens: Optional[int] = Field(None, alias="maxTokens")
81-
batch_size: int = Field(..., alias="batchSize")
81+
batch_size: Optional[int] = Field(None, alias="batchSize")
8282
loss: Optional[str] = "rl"
8383
teacher: Optional[Dict[str, Any]] = Field(
8484
None,
8585
validation_alias=AliasChoices("teacher", "teacherConfig"),
8686
serialization_alias="teacher",
8787
)
88-
base_model: str = Field(..., alias="baseModel")
88+
base_model: Optional[str] = Field(None, alias="baseModel")
8989
environments: List[Dict[str, Any]] = Field(default_factory=list)
9090
run_config: Optional[Dict[str, Any]] = Field(None, alias="runConfig")
9191
eval_config: Optional[Dict[str, Any]] = Field(None, alias="evalConfig")
@@ -423,16 +423,25 @@ def get_logs(
423423
regex: bool = False,
424424
level: Optional[str] = None,
425425
since_seconds: Optional[int] = None,
426+
component: Optional[str] = None,
427+
pod_index: int = 0,
428+
env_name: Optional[str] = None,
426429
) -> str:
427-
"""Get orchestrator logs for a Hosted Training run.
430+
"""Get logs for one component of a Hosted Training run.
431+
432+
Defaults to the orchestrator pod. Dedicated full-FT runs additionally
433+
expose `trainer`, `inference`, and `env-server` components.
434+
`pod_index` narrows to a specific replica for multi-node
435+
trainer/inference; `env_name` picks among per-env env-server
436+
StatefulSets when `component='env-server'`.
428437
429438
Optional filters narrow the result via the platform's log search
430439
backend:
431440
- search: substring (or regex if regex=True) line filter
432441
- level: one of ERROR/WARNING/SUCCESS/INFO/DEBUG
433-
- since_seconds: how far back to look (6086400)
442+
- since_seconds: how far back to look (60-86400)
434443
"""
435-
params: Dict[str, object] = {"tail_lines": tail_lines}
444+
params: Dict[str, Any] = {"tail_lines": tail_lines}
436445
if search:
437446
params["search"] = search
438447
if regex:
@@ -441,6 +450,12 @@ def get_logs(
441450
params["level"] = level
442451
if since_seconds is not None:
443452
params["since_seconds"] = since_seconds
453+
if component:
454+
params["component"] = component
455+
if pod_index:
456+
params["pod_index"] = pod_index
457+
if env_name:
458+
params["env_name"] = env_name
444459
try:
445460
response = self.client.get(f"/rft/runs/{run_id}/logs", params=params)
446461
return response.get("logs", "")

packages/prime/src/prime_cli/commands/rl.py

Lines changed: 71 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -916,10 +916,10 @@ def _format_run_for_display(run: RLRun) -> Dict[str, Any]:
916916
return {
917917
"id": run.id,
918918
"status": run.status,
919-
"model": run.base_model,
919+
"model": run.base_model or "-",
920920
"environments": envs_display,
921-
"steps": f"{run.max_steps}",
922-
"rollouts": str(run.rollouts_per_example),
921+
"steps": "-" if run.max_steps is None else f"{run.max_steps}",
922+
"rollouts": "-" if run.rollouts_per_example is None else str(run.rollouts_per_example),
923923
"created_at": created_at,
924924
"team_id": run.team_id,
925925
}
@@ -1776,11 +1776,12 @@ def get_run(
17761776
if run.status == "QUEUED" and run.runs_ahead is not None:
17771777
status_text += f" (~{run.runs_ahead} runs ahead)"
17781778
console.print(f" Status: [{status_color}]{status_text}[/{status_color}]")
1779-
console.print(f" Model: [magenta]{run.base_model}[/magenta]")
1779+
console.print(f" Model: [magenta]{formatted['model']}[/magenta]")
17801780
console.print(f" Environments: [green]{formatted['environments']}[/green]")
1781-
console.print(f" Max Steps: {run.max_steps}")
1782-
console.print(f" Batch Size: {run.batch_size}")
1783-
console.print(f" Rollouts per Example: {run.rollouts_per_example}")
1781+
console.print(f" Max Steps: {formatted['steps']}")
1782+
batch_size = "-" if run.batch_size is None else str(run.batch_size)
1783+
console.print(f" Batch Size: {batch_size}")
1784+
console.print(f" Rollouts per Example: {formatted['rollouts']}")
17841785
if run.max_tokens:
17851786
console.print(f" Max Tokens: {run.max_tokens}")
17861787
if run.wandb_project:
@@ -2078,6 +2079,14 @@ def _parse_env_qualifier(env: str) -> tuple[str, int]:
20782079
return env, 0
20792080

20802081

2082+
def _parse_env_qualifier_with_index(env: str) -> tuple[str, int, bool]:
2083+
"""Parse an env qualifier and report whether a numeric suffix was present."""
2084+
name, sep, idx_str = env.rpartition("/")
2085+
if sep and name and idx_str.isdigit():
2086+
return name, int(idx_str), True
2087+
return env, 0, False
2088+
2089+
20812090
@app.command("logs", rich_help_panel="Monitoring")
20822091
def get_logs(
20832092
run_id: str = typer.Argument(..., help="Run ID to get logs for"),
@@ -2086,8 +2095,9 @@ def get_logs(
20862095
"--component",
20872096
"-c",
20882097
help=(
2089-
"Pod to read logs from: 'orchestrator' (default) or 'env-server'. "
2090-
"Inferred from --env when omitted."
2098+
"Pod to read logs from: 'orchestrator' (default), 'trainer', "
2099+
"'inference', or 'env-server'. trainer/inference apply only "
2100+
"to dedicated full-FT runs. Inferred from --env when omitted."
20912101
),
20922102
),
20932103
env: Optional[str] = typer.Option(
@@ -2132,33 +2142,41 @@ def get_logs(
21322142
) -> None:
21332143
"""Get logs for a run.
21342144
2135-
Defaults to the orchestrator pod. Pass ``--env <name>`` to read an
2136-
env-server pod instead — useful when an env-server is crash-looping
2137-
(e.g. ``ModuleNotFoundError``) and the orchestrator has stalled at
2138-
"Starting orchestrator step 0".
2145+
Defaults to the orchestrator pod. Use ``--component`` to pick one of
2146+
``trainer`` / ``inference`` / ``env-server`` (dedicated full-FT only).
2147+
Pass ``--env <name>`` to read an env-server pod by name (shorthand for
2148+
``--component=env-server``).
21392149
21402150
List available pods first with ``prime train components <run_id>``.
21412151
2152+
Per-rank narrowing on multi-replica trainer/inference is not yet
2153+
surfaced here — `--local-ranks-filter=0` in the chart's torchrun
2154+
invocation already dedupes the in-pod rank fan-out, and per-pod
2155+
inspection on multi-node runs requires kubectl + the PVC log files.
2156+
21422157
Examples:
21432158
21442159
prime train logs <run_id>
21452160
prime train logs <run_id> -f
21462161
prime train logs <run_id> --search Backpressure
21472162
prime train logs <run_id> --level ERROR --since 1h
21482163
prime train logs <run_id> --search 'Step \\d+' --regex
2164+
prime train logs <run_id> -c trainer
2165+
prime train logs <run_id> -c inference
21492166
prime train logs <run_id> --env reverse-text
21502167
prime train logs <run_id> --env reverse-text/1 -f
21512168
"""
2169+
valid_components = ("orchestrator", "trainer", "inference", "env-server")
21522170
if component is None:
21532171
component = "env-server" if env is not None else "orchestrator"
2154-
elif component not in ("orchestrator", "env-server"):
2172+
elif component not in valid_components:
21552173
raise typer.BadParameter(
2156-
f"Invalid component '{component}'. Use 'orchestrator' or 'env-server'.",
2174+
f"Invalid component '{component}'. Use one of: {', '.join(valid_components)}.",
21572175
param_hint="--component",
21582176
)
2159-
if component == "orchestrator" and env is not None:
2177+
if env is not None and component != "env-server":
21602178
raise typer.BadParameter(
2161-
"--env applies only to env-server logs. Drop --component=orchestrator or drop --env.",
2179+
f"--env applies only to env-server logs. Drop --component={component} or drop --env.",
21622180
param_hint="--env",
21632181
)
21642182
if component == "env-server" and env is None:
@@ -2189,7 +2207,33 @@ def get_logs(
21892207
api_client = APIClient()
21902208
rl_client = RLClient(api_client)
21912209

2192-
if component == "orchestrator":
2210+
env_name_q, env_index_q, env_has_index_q = (
2211+
_parse_env_qualifier_with_index(env) if env is not None else (None, 0, False)
2212+
)
2213+
2214+
if component == "env-server" and env is not None and env_has_index_q:
2215+
assert env_name_q is not None
2216+
# Legacy shared-RFT env-server (`name/index` qualifier) — go
2217+
# through the dedicated env-server endpoint which uses the
2218+
# cluster_id-backed pod lookup path. Dedicated full-FT
2219+
# env-servers use the unified /logs route with
2220+
# component=env-server + env_name (StatefulSets always run
2221+
# one pod per env, so no index disambiguation needed).
2222+
2223+
def fetch(t: int) -> str:
2224+
return rl_client.get_env_server_logs(
2225+
run_id,
2226+
env_name=env_name_q,
2227+
env_index=env_index_q,
2228+
tail_lines=t,
2229+
search=search,
2230+
regex=regex,
2231+
level=normalized_level,
2232+
since_seconds=since_seconds,
2233+
)
2234+
2235+
label = f"env-server {env}"
2236+
elif component == "orchestrator":
21932237

21942238
def fetch(t: int) -> str:
21952239
return rl_client.get_logs(
@@ -2203,22 +2247,25 @@ def fetch(t: int) -> str:
22032247

22042248
label = "orchestrator"
22052249
else:
2206-
assert env is not None # narrowed by validation above
2207-
env_name, env_index = _parse_env_qualifier(env)
2250+
# trainer / inference / dedicated env-server — unified /logs
2251+
# route. env (no slash) names the dedicated env-server's
2252+
# StatefulSet.
2253+
fetch_component = component
2254+
fetch_env = env if component == "env-server" else None
22082255

22092256
def fetch(t: int) -> str:
2210-
return rl_client.get_env_server_logs(
2257+
return rl_client.get_logs(
22112258
run_id,
2212-
env_name=env_name,
2213-
env_index=env_index,
22142259
tail_lines=t,
22152260
search=search,
22162261
regex=regex,
22172262
level=normalized_level,
22182263
since_seconds=since_seconds,
2264+
component=fetch_component,
2265+
env_name=fetch_env,
22192266
)
22202267

2221-
label = f"env-server {env}"
2268+
label = f"env-server {env}" if component == "env-server" else component
22222269

22232270
_stream_logs(
22242271
fetch_fn=fetch,

packages/prime/tests/test_rl_api.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import Any
44

5-
from prime_cli.api.rl import RLClient
5+
from prime_cli.api.rl import RLClient, RLRun
66

77

88
class FakeAPIClient:
@@ -46,6 +46,27 @@ def post(self, endpoint: str, json: dict[str, Any] | None = None) -> dict[str, A
4646
}
4747

4848

49+
def test_run_model_allows_dedicated_full_ft_without_lora_fields() -> None:
50+
run = RLRun.model_validate(
51+
{
52+
"id": "full-ft-run",
53+
"name": "dedicated",
54+
"userId": "user-1",
55+
"status": "RUNNING",
56+
"kind": "DEDICATED_FULL_FT",
57+
"createdAt": "2026-05-17T00:00:00Z",
58+
"updatedAt": "2026-05-17T00:00:00Z",
59+
}
60+
)
61+
62+
assert run.kind == "DEDICATED_FULL_FT"
63+
assert run.rollouts_per_example is None
64+
assert run.seq_len is None
65+
assert run.max_steps is None
66+
assert run.batch_size is None
67+
assert run.base_model is None
68+
69+
4970
def test_get_distributions_preserves_chart_histogram_data() -> None:
5071
api_client = FakeAPIClient()
5172
client = RLClient(api_client) # type: ignore[arg-type]

0 commit comments

Comments
 (0)