Skip to content

Commit 67a7c70

Browse files
author
Francisco
committed
feat(sdk): add wait_until_ready and wait_for_completion helpers
Move the dataset polling loop and the training job polling loop out of every user's integration scripts and into the SDK. wait_until_ready raises if the dataset fails or the timeout expires. wait_for_completion returns the terminal job for any of completed/failed/cancelled and accepts an optional on_progress callback for live metric streaming.
1 parent c2682db commit 67a7c70

2 files changed

Lines changed: 104 additions & 0 deletions

File tree

src/projectdavid/clients/datasets_client.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,48 @@ def retrieve(self, dataset_id: str) -> validator.DatasetRead:
9797
response.raise_for_status()
9898
return validator.DatasetRead.model_validate(response.json())
9999

100+
# ------------------------------------------------------------------
101+
# WAIT UNTIL READY
102+
# ------------------------------------------------------------------
103+
def wait_until_ready(
    self,
    dataset_id: str,
    *,
    timeout: float = 300.0,
    poll_interval: float = 3.0,
) -> validator.DatasetRead:
    """
    Block until the dataset reaches 'active' status.

    The dataset is fetched with `self.retrieve` immediately and then
    roughly every `poll_interval` seconds. The status is always checked
    at least once, even when `timeout` is zero or negative, and the wait
    between polls is shortened so the last check happens at the deadline
    instead of after it.

    Args:
        dataset_id: Identifier passed to `self.retrieve` on every poll.
        timeout: Maximum number of seconds to wait for 'active'.
        poll_interval: Seconds to sleep between two consecutive polls.

    Returns:
        The DatasetRead fetched by the poll that observed status 'active'.

    Raises:
        RuntimeError: The dataset reported status 'failed'.
        TimeoutError: The dataset was neither 'active' nor 'failed' by
            the time `timeout` seconds had elapsed.
    """
    import time  # local import keeps top-of-file imports minimal

    deadline = time.monotonic() + timeout
    last_status = None

    while True:
        ds = self.retrieve(dataset_id)

        # Log transitions only, so a long wait does not flood the log.
        if ds.status != last_status:
            logging_utility.info("Dataset %s status=%s", dataset_id, ds.status)
            last_status = ds.status

        if ds.status == "active":
            return ds
        if ds.status == "failed":
            raise RuntimeError(
                f"Dataset {dataset_id} preparation failed (status=failed)"
            )

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Never sleep past the deadline: the next poll doubles as the
        # final check before giving up.
        time.sleep(min(poll_interval, remaining))

    raise TimeoutError(
        f"Dataset {dataset_id} did not reach 'active' within {timeout}s "
        f"(last status={last_status})"
    )
141+
100142
# ------------------------------------------------------------------
101143
# LIST
102144
# ------------------------------------------------------------------

src/projectdavid/clients/training_client.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,68 @@ def cancel(self, job_id: str) -> validator.TrainingJobCancelResponse:
8585
response.raise_for_status()
8686
return validator.TrainingJobCancelResponse.model_validate(response.json())
8787

88+
# ------------------------------------------------------------------
89+
# WAIT FOR COMPLETION
90+
# ------------------------------------------------------------------
91+
def wait_for_completion(
    self,
    job_id: str,
    *,
    on_progress=None,
    poll_interval: float = 10.0,
    timeout: float = 7200.0,
) -> validator.TrainingJobRead:
    """
    Block until the training job reaches a terminal state.

    Terminal states: 'completed', 'failed', 'cancelled'.

    The job is fetched with `self.retrieve` immediately and then roughly
    every `poll_interval` seconds. The status is always checked at least
    once, even when `timeout` is zero or negative, and the wait between
    polls is shortened so the last check happens at the deadline instead
    of after it.

    An optional `on_progress` callback receives the job's metrics mapping
    whenever a poll observes a 'step' value different from the previously
    reported one. Because metrics are sampled by polling, steps that come
    and go between two polls are not delivered. Exceptions raised by the
    callback are logged as warnings and do not interrupt the wait. Use it
    to drive live progress display without coupling the SDK to any
    particular output format.

    Example:
        def show(m):
            print(f"step={m.get('step')} loss={m.get('loss')}")

        job = client.training.wait_for_completion(job.id, on_progress=show)

    Args:
        job_id: Identifier passed to `self.retrieve` on every poll.
        on_progress: Optional callable taking the metrics mapping.
        poll_interval: Seconds to sleep between two consecutive polls.
        timeout: Maximum number of seconds to wait for a terminal state.

    Returns:
        The job fetched by the poll that observed a terminal status.

    Raises TimeoutError if the job has not reached a terminal state
    within `timeout` seconds. Does NOT raise on a 'failed' or 'cancelled'
    job — callers inspect the returned job.status themselves, since a
    failed job is often still interesting (last_error, partial metrics).
    """
    import time  # local import keeps top-of-file imports minimal

    TERMINAL_STATES = {"completed", "failed", "cancelled"}
    deadline = time.monotonic() + timeout
    # None means "nothing reported yet"; a numeric sentinel could collide
    # with a real step value and suppress its callback.
    last_step = None

    while True:
        job = self.retrieve(job_id)

        # Metrics may be absent or empty early in a job's life.
        metrics = getattr(job, "metrics", None) or {}
        step = metrics.get("step")
        if on_progress is not None and step is not None and step != last_step:
            last_step = step
            try:
                on_progress(metrics)
            except Exception as cb_err:
                # Callback failures must not break polling.
                logging_utility.warning(
                    "on_progress callback raised %s: %s",
                    type(cb_err).__name__,
                    cb_err,
                )

        if job.status in TERMINAL_STATES:
            return job

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Never sleep past the deadline: the next poll doubles as the
        # final check before giving up.
        time.sleep(min(poll_interval, remaining))

    raise TimeoutError(f"Training job {job_id} did not complete within {timeout}s")
149+
88150
# ------------------------------------------------------------------
89151
# DIAGNOSTIC PEEK (Secure Multi-tenant Gateway)
90152
# ------------------------------------------------------------------

0 commit comments

Comments
 (0)