-
Notifications
You must be signed in to change notification settings - Fork 44
feature: prime train routes full_finetune TOMLs to hosted endpoint #592
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
932690d
aff9474
d0d27c7
eac6f84
d254e18
3bc37f9
91d2a74
12e846a
1bc2610
4d900e3
feea570
e5c0e2f
a41e4b8
bc982f4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,93 @@ | ||
| """Hosted full-FT training API client (POST/DELETE /v1/training/runs). | ||
|
|
||
| Sibling to api/rl.py — that's the LoRA-shared path. This client speaks to | ||
| the dedicated full-parameter prime-rl endpoint where each run gets its | ||
| own helm release on a registered PrimeCluster. Auth is the standard API | ||
| token; admin role is gated server-side. | ||
| """ | ||
|
|
||
| from typing import Any, Dict, Optional | ||
|
|
||
| from pydantic import BaseModel, ConfigDict, Field | ||
|
|
||
| from prime_cli.core import APIClient | ||
|
|
||
|
|
||
| class HostedTrainingRunResponse(BaseModel): | ||
| """Response from POST /v1/training/runs.""" | ||
|
|
||
| run_id: str = Field(..., alias="runId") | ||
| job_id: str = Field(..., alias="jobId") | ||
| token_value: str = Field(..., alias="tokenValue") | ||
|
|
||
| model_config = ConfigDict(populate_by_name=True) | ||
|
|
||
|
|
||
| class HostedTrainingClient: | ||
| """Client for the hosted full-FT training endpoint.""" | ||
|
|
||
| def __init__(self, client: APIClient) -> None: | ||
| self.client = client | ||
|
|
||
| def create_run(self, payload: Dict[str, Any]) -> HostedTrainingRunResponse: | ||
| """POST /v1/training/runs. Backend mints a per-run API token and | ||
| kicks off the helm install asynchronously; returns immediately with | ||
| identifiers. | ||
|
|
||
| Lets typed APIError subclasses (NotFoundError, UnauthorizedError, | ||
| …) propagate so callers can branch by exception class instead of | ||
| string-matching the message. | ||
| """ | ||
| response = self.client.post("/training/runs", json=payload) | ||
| return HostedTrainingRunResponse.model_validate(response) | ||
|
|
||
| def delete_run(self, run_id: str) -> Dict[str, Any]: | ||
| """DELETE /v1/training/runs/{run_id}. Idempotent: helm uninstall + | ||
| namespace delete + RFTRun soft-delete. Returns the wire payload | ||
| (typically {runId, deleted}). Re-raises typed APIError subclasses | ||
| (NotFoundError on 404, etc.) so callers can branch by exception | ||
| class — `prime train delete` uses NotFoundError as the 'not a | ||
| hosted run, try LoRA' fallback signal. | ||
| """ | ||
| response = self.client.request("DELETE", f"/training/runs/{run_id}") | ||
| return response if isinstance(response, dict) else {"runId": run_id} | ||
|
|
||
|
|
||
| def build_payload_from_toml( | ||
| cfg: Dict[str, Any], | ||
| *, | ||
| name: Optional[str] = None, | ||
| team_id: Optional[str] = None, | ||
| image_tag: Optional[str] = None, | ||
| wandb_api_key: Optional[str] = None, | ||
| hf_token: Optional[str] = None, | ||
| ) -> Dict[str, Any]: | ||
| """Build the /v1/training/runs payload from a prime-rl-style TOML dict. | ||
|
|
||
| Ships the *whole* TOML as `config` so the backend can split it | ||
| per-component (trainer / orchestrator / inference) and bake each | ||
| into the corresponding pod's startup command. Anything outside the | ||
| handful of platform-authoritative overlays (chart-side scrape ports, | ||
| monitor URL, secret name) flows through unchanged — same e2e | ||
| behaviour as `uv run rl @ rl.toml`. | ||
|
|
||
| What stays out of `config`: | ||
| - secrets (wandb / hf): materialised into a per-run k8s Secret, | ||
| - run name: lives on the platform's RFTRun row, not the TOML, | ||
| - team_id: links the RFTRun to a team for billing/access scoping, | ||
| - image_tag: chart-level (which prime-rl image to pull). | ||
|
|
||
| Cluster targeting is backend-side (auto-pick first uncordoned). | ||
| """ | ||
| payload: Dict[str, Any] = {"config": cfg} | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
When a full-finetune TOML uses the supported Useful? React with 👍 / 👎. |
||
| if name: | ||
| payload["name"] = name | ||
| if team_id: | ||
| payload["teamId"] = team_id | ||
| if image_tag: | ||
| payload["imageTag"] = image_tag | ||
| if wandb_api_key: | ||
| payload["wandbApiKey"] = wandb_api_key | ||
| if hf_token: | ||
| payload["hfToken"] = hf_token | ||
| return payload | ||
Uh oh!
There was an error while loading. Please reload this page.