Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions packages/prime/src/prime_cli/api/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,24 @@ class RLRun(BaseModel):
team_id: Optional[str] = Field(None, alias="teamId")
cluster_id: Optional[str] = Field(None, alias="rftClusterId")
status: str = Field(..., description="Run status")
# Discriminator: SHARED_RFT_HOSTED (LoRA) | DEDICATED_FULL_FT (own
# helm release on a PrimeCluster) | EXTERNAL (CLI-side prime-rl).
# Optional for backward-compat with older API versions.
kind: Optional[str] = Field(None, description="Run kind discriminator")
Comment thread
JannikSt marked this conversation as resolved.

# Training configuration
rollouts_per_example: int = Field(..., alias="rolloutsPerExample")
seq_len: int = Field(..., alias="seqLen")
max_steps: int = Field(..., alias="maxSteps")
rollouts_per_example: Optional[int] = Field(None, alias="rolloutsPerExample")
seq_len: Optional[int] = Field(None, alias="seqLen")
max_steps: Optional[int] = Field(None, alias="maxSteps")
max_tokens: Optional[int] = Field(None, alias="maxTokens")
batch_size: int = Field(..., alias="batchSize")
batch_size: Optional[int] = Field(None, alias="batchSize")
loss: Optional[str] = "rl"
teacher: Optional[Dict[str, Any]] = Field(
None,
validation_alias=AliasChoices("teacher", "teacherConfig"),
serialization_alias="teacher",
)
base_model: str = Field(..., alias="baseModel")
base_model: Optional[str] = Field(None, alias="baseModel")
environments: List[Dict[str, Any]] = Field(default_factory=list)
run_config: Optional[Dict[str, Any]] = Field(None, alias="runConfig")
eval_config: Optional[Dict[str, Any]] = Field(None, alias="evalConfig")
Expand Down Expand Up @@ -419,16 +423,25 @@ def get_logs(
regex: bool = False,
level: Optional[str] = None,
since_seconds: Optional[int] = None,
component: Optional[str] = None,
pod_index: Optional[int] = None,
env_name: Optional[str] = None,
) -> str:
"""Get orchestrator logs for a Hosted Training run.
"""Get logs for one component of a Hosted Training run.

Defaults to the orchestrator pod. Dedicated full-FT runs additionally
expose `trainer`, `inference`, and `env-server` components.
`pod_index` narrows to a specific replica for multi-node
trainer/inference; `env_name` picks among per-env env-server
StatefulSets when `component='env-server'`.

Optional filters narrow the result via the platform's log search
backend:
- search: substring (or regex if regex=True) line filter
- level: one of ERROR/WARNING/SUCCESS/INFO/DEBUG
- since_seconds: how far back to look (6086400)
- since_seconds: how far back to look (60-86400)
"""
params: Dict[str, object] = {"tail_lines": tail_lines}
params: Dict[str, Any] = {"tail_lines": tail_lines}
if search:
params["search"] = search
if regex:
Expand All @@ -437,6 +450,12 @@ def get_logs(
params["level"] = level
if since_seconds is not None:
params["since_seconds"] = since_seconds
if component:
params["component"] = component
if pod_index is not None:
params["pod_index"] = pod_index
Comment thread
cursor[bot] marked this conversation as resolved.
if env_name:
params["env_name"] = env_name
try:
response = self.client.get(f"/rft/runs/{run_id}/logs", params=params)
return response.get("logs", "")
Expand Down
93 changes: 93 additions & 0 deletions packages/prime/src/prime_cli/api/training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Hosted full-FT training API client (POST/DELETE /v1/training/runs).

Sibling to api/rl.py — that's the LoRA-shared path. This client speaks to
the dedicated full-parameter prime-rl endpoint where each run gets its
own helm release on a registered PrimeCluster. Auth is the standard API
token; admin role is gated server-side.
"""

from typing import Any, Dict, Optional

from pydantic import BaseModel, ConfigDict, Field

from prime_cli.core import APIClient


class HostedTrainingRunResponse(BaseModel):
    """Identifiers returned by POST /v1/training/runs.

    Every field is required and is declared with the camelCase alias used
    on the wire. `populate_by_name` additionally lets the model be built
    from the snake_case attribute names.
    """

    # Accept the Python field names as well as the camelCase aliases.
    model_config = ConfigDict(populate_by_name=True)

    # Wire key: "runId".
    run_id: str = Field(..., alias="runId")
    # Wire key: "jobId".
    job_id: str = Field(..., alias="jobId")
    # Wire key: "tokenValue" — presumably the per-run API token minted by
    # the backend on create; confirm against the API schema.
    token_value: str = Field(..., alias="tokenValue")


class HostedTrainingClient:
    """Thin wrapper over the hosted full-FT training routes."""

    def __init__(self, client: APIClient) -> None:
        # Every request is delegated to this shared API client.
        self.client = client

    def create_run(self, payload: Dict[str, Any]) -> HostedTrainingRunResponse:
        """Create a run via POST /v1/training/runs.

        The backend mints a per-run API token and starts the helm install
        asynchronously, so this returns right away with the identifiers.

        No exception is caught here: typed APIError subclasses
        (NotFoundError, UnauthorizedError, …) reach the caller, which can
        branch on the exception class rather than string-matching the
        message.

        Args:
            payload: JSON body for the request.

        Returns:
            The response validated into a HostedTrainingRunResponse.
        """
        raw = self.client.post("/training/runs", json=payload)
        return HostedTrainingRunResponse.model_validate(raw)

    def delete_run(self, run_id: str) -> Dict[str, Any]:
        """Delete a run via DELETE /v1/training/runs/{run_id}.

        The endpoint is idempotent: helm uninstall + namespace delete +
        RFTRun soft-delete. Typed APIError subclasses (NotFoundError on
        404, etc.) propagate unchanged — `prime train delete` treats
        NotFoundError as the 'not a hosted run, try LoRA' fallback signal.

        Args:
            run_id: Identifier of the run to delete.

        Returns:
            The wire payload (typically {runId, deleted}) when the
            response is a dict; otherwise a minimal {"runId": run_id}.
        """
        result = self.client.request("DELETE", f"/training/runs/{run_id}")
        if isinstance(result, dict):
            return result
        # Non-dict body: fall back to echoing the id that was requested.
        return {"runId": run_id}


# Top-level TOML keys that only the CLI understands. They name env files on
# the local machine, so they are dropped before the config is dispatched:
# the paths will not exist inside the hosted pod.
_CLI_ONLY_CONFIG_KEYS = frozenset({"env_file", "env_files"})


def build_payload_from_toml(
    cfg: Dict[str, Any],
    *,
    name: Optional[str] = None,
    team_id: Optional[str] = None,
    image_tag: Optional[str] = None,
    wandb_api_key: Optional[str] = None,
    hf_token: Optional[str] = None,
) -> Dict[str, Any]:
    """Build the /v1/training/runs payload from a prime-rl-style TOML dict.

    Ships the TOML as `config` so the backend can split it per-component
    (trainer / orchestrator / inference) and bake each into the
    corresponding pod's startup command. Anything outside the handful of
    platform-authoritative overlays (chart-side scrape ports, monitor URL,
    secret name) flows through unchanged — same e2e behaviour as
    `uv run rl @ rl.toml`.

    The one exception is the CLI-only top-level keys `env_file` and
    `env_files`: they are removed from `config` because they reference
    local paths. Any values from those files are expected to reach the
    backend as secrets, not as config.

    What stays out of `config`:
    - secrets (wandb / hf): materialised into a per-run k8s Secret,
    - run name: lives on the platform's RFTRun row, not the TOML,
    - team_id: links the RFTRun to a team for billing/access scoping,
    - image_tag: chart-level (which prime-rl image to pull).

    Cluster targeting is backend-side (auto-pick first uncordoned).

    Args:
        cfg: Parsed TOML config. Not mutated; `config` in the result is a
            filtered shallow copy.
        name: Optional run name, sent as `name`.
        team_id: Optional team id, sent as `teamId`.
        image_tag: Optional image tag, sent as `imageTag`.
        wandb_api_key: Optional W&B key, sent as `wandbApiKey`.
        hf_token: Optional Hugging Face token, sent as `hfToken`.

    Returns:
        The request body: `config` plus every optional field that was
        given a truthy value.
    """
    # Shallow copy without the CLI-only keys; the caller's dict is left
    # untouched so it can keep using env_file/env_files locally.
    config = {
        key: value
        for key, value in cfg.items()
        if key not in _CLI_ONLY_CONFIG_KEYS
    }
    payload: Dict[str, Any] = {"config": config}

    optional_fields = (
        ("name", name),
        ("teamId", team_id),
        ("imageTag", image_tag),
        ("wandbApiKey", wandb_api_key),
        ("hfToken", hf_token),
    )
    for wire_key, value in optional_fields:
        # Falsy values (None or empty string) are omitted from the body.
        if value:
            payload[wire_key] = value
    return payload
2 changes: 2 additions & 0 deletions packages/prime/src/prime_cli/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
APIError,
APITimeoutError,
AsyncAPIClient,
NotFoundError,
PaymentRequiredError,
UnauthorizedError,
ValidationError,
Expand All @@ -15,6 +16,7 @@
"AsyncAPIClient",
"APIError",
"APITimeoutError",
"NotFoundError",
"PaymentRequiredError",
"UnauthorizedError",
"ValidationError",
Expand Down
Loading
Loading