Skip to content

Commit d9544c6

Browse files
author
Francisco
committed
fix(training): live progress metrics persist mid-training
Two independent bugs suppressed all but the final PROGRESS emit, so job.metrics stayed None during training and only the post-training summary landed in the DB. 1. unsloth_train.py ProgressEmitter wrote PROGRESS:{...} without a leading newline. HuggingFace tqdm writes its step progress bar with carriage returns (no trailing newline), causing PROGRESS lines to concatenate onto the tqdm output. The worker's parser used line.startswith('PROGRESS:') which failed on the concatenated form. Prepending \\\\n guarantees each PROGRESS emit lands on its own line regardless of tqdm state. 2. worker.py performed DB commits synchronously inside the subprocess stdout read loop. Moved writes to a daemon thread drained from a queue.Queue so stdout reads never block on DB latency. Added attempted/succeeded/failed counters for observability. Live training metrics (step, epoch, loss, learning_rate) now land in job.metrics on every logging step and are visible via client.training.retrieve().
1 parent 69ee9cb commit d9544c6

2 files changed

Lines changed: 114 additions & 18 deletions

File tree

src/api/training/unsloth_train.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,23 @@ class ProgressEmitter(TrainerCallback):
4747
4848
Output format (one line per logging step):
4949
PROGRESS:{"step": 5, "total_steps": 20, "epoch": 0.25, "loss": 1.423, "learning_rate": 0.0002}
50+
51+
Note on leading newline:
52+
HuggingFace Transformers uses tqdm for its progress bar, which writes
53+
progress updates to stdout without a trailing newline (it uses carriage
54+
returns to update in-place). When our PROGRESS print fires on the same
55+
logging step, its output ends up concatenated to the end of the tqdm
56+
line, e.g.:
57+
58+
5%|▌ | 1/20 [00:03<01:08, 3.61s/it]PROGRESS:{"step": 1, ...}
59+
60+
The downstream parser in worker.py uses line.startswith("PROGRESS:")
61+
which fails on that concatenated form — only the final, clean PROGRESS
62+
emit (after tqdm is done) gets captured.
63+
64+
Prepending "\n" guarantees our PROGRESS line starts on its own line
65+
regardless of what tqdm has done to stdout. The worker's stdout reader
66+
then sees it as an independent line and matches cleanly.
5067
"""
5168

5269
def on_log(self, args, state, control, logs=None, **kwargs):
@@ -59,7 +76,7 @@ def on_log(self, args, state, control, logs=None, **kwargs):
5976
"loss": round(logs.get("loss", 0), 4),
6077
"learning_rate": logs.get("learning_rate"),
6178
}
62-
print(f"PROGRESS:{json.dumps(progress)}", flush=True)
79+
print(f"\nPROGRESS:{json.dumps(progress)}", flush=True)
6380

6481

6582
# ──────────────────────────────────────────────────────────────────────────────

src/api/training/worker.py

Lines changed: 96 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323

2424
import json
2525
import os
26+
import queue
2627
import socket
28+
import threading
2729
import time
2830

2931
import redis
@@ -58,9 +60,9 @@ def process_job(job_id: str, user_id: str):
5860
5961
Progress feedback:
6062
unsloth_train.py emits PROGRESS:{...} lines on every logging step.
61-
This loop parses those lines and writes them to job.metrics so users
62-
polling client.training.retrieve(job_id) get live loss and step count
63-
rather than a black hole between dispatch and completion.
63+
The read loop parses these and pushes metrics dicts to a queue.
64+
A background writer thread drains the queue and commits to the DB,
65+
so DB latency never stalls the subprocess stdout pipe.
6466
6567
All imports are local to keep the module lightweight and avoid
6668
import-time side effects from SQLAlchemy / ORM modules.
@@ -123,6 +125,7 @@ def _get_samba_client():
123125

124126
cmd = [
125127
"python",
128+
"-u",
126129
"src/api/training/unsloth_train.py",
127130
"--model",
128131
job.base_model,
@@ -135,27 +138,103 @@ def _get_samba_client():
135138
]
136139

137140
process = _subprocess.Popen( # nosec B603
138-
cmd, stdout=_subprocess.PIPE, stderr=_subprocess.STDOUT, text=True
141+
cmd,
142+
stdout=_subprocess.PIPE,
143+
stderr=_subprocess.STDOUT,
144+
text=True,
145+
bufsize=1,
139146
)
140147

141-
# ─── STDOUT LOOP WITH PROGRESS PARSING ───────────────────────────────
148+
# ─── STDOUT LOOP WITH PROGRESS PARSING (THREADED DB WRITES) ──────────
142149
# unsloth_train.py emits PROGRESS:{...} lines on every logging step.
143-
# We parse these and write them to job.metrics so polling clients
144-
# get live feedback. Non-PROGRESS lines are logged normally.
145-
for line in process.stdout:
146-
line = line.strip()
147-
logging_utility.info(f"[{job_id}] {line}")
150+
# Read loop parses lines and pushes metrics dicts onto a queue.
151+
# A daemon writer thread drains the queue and commits to the DB,
152+
# decoupling slow DB writes from the hot stdout-read path.
153+
metrics_queue: queue.Queue = queue.Queue()
154+
writer_stop = threading.Event()
155+
write_counter = {"attempted": 0, "succeeded": 0, "failed": 0}
156+
157+
def _metrics_writer():
158+
while not writer_stop.is_set() or not metrics_queue.empty():
159+
try:
160+
m = metrics_queue.get(timeout=0.5)
161+
except queue.Empty:
162+
continue
148163

149-
if line.startswith("PROGRESS:"):
164+
write_counter["attempted"] += 1
165+
t_start = _time.time()
166+
_db = _SessionLocal()
150167
try:
151-
metrics = json.loads(line[9:])
152-
job.metrics = metrics
153-
job.updated_at = int(_time.time())
154-
db.commit()
155-
except Exception as parse_err:
168+
_job = (
169+
_db.query(_TrainingJob)
170+
.filter(_TrainingJob.id == job_id)
171+
.first()
172+
)
173+
if _job:
174+
_job.metrics = m
175+
_job.updated_at = int(_time.time())
176+
_db.commit()
177+
write_counter["succeeded"] += 1
178+
logging_utility.info(
179+
f"[{job_id}] DB_WRITE_OK step={m.get('step')} "
180+
f"elapsed={(_time.time() - t_start) * 1000:.1f}ms "
181+
f"attempted={write_counter['attempted']} "
182+
f"succeeded={write_counter['succeeded']}"
183+
)
184+
else:
185+
write_counter["failed"] += 1
186+
logging_utility.warning(
187+
f"[{job_id}] DB_WRITE_MISS — TrainingJob not found"
188+
)
189+
except Exception as write_err:
190+
write_counter["failed"] += 1
156191
logging_utility.warning(
157-
f"[{job_id}] Failed to parse PROGRESS line: {parse_err}"
192+
f"[{job_id}] DB_WRITE_FAIL step={m.get('step')} "
193+
f"err={write_err}"
158194
)
195+
finally:
196+
_db.close()
197+
metrics_queue.task_done()
198+
199+
writer_thread = threading.Thread(
200+
target=_metrics_writer, daemon=True, name=f"metrics_writer_{job_id}"
201+
)
202+
writer_thread.start()
203+
logging_utility.info(f"[{job_id}] Metrics writer thread started")
204+
205+
try:
206+
for line in process.stdout:
207+
line = line.strip()
208+
logging_utility.info(f"[{job_id}] {line}")
209+
210+
if line.startswith("PROGRESS:"):
211+
try:
212+
metrics = json.loads(line[9:])
213+
metrics_queue.put(metrics)
214+
logging_utility.info(
215+
f"[{job_id}] QUEUED step={metrics.get('step')} "
216+
f"qsize={metrics_queue.qsize()}"
217+
)
218+
except Exception as parse_err:
219+
logging_utility.warning(
220+
f"[{job_id}] Failed to parse PROGRESS line: {parse_err}"
221+
)
222+
finally:
223+
logging_utility.info(
224+
f"[{job_id}] Stopping metrics writer — "
225+
f"attempted={write_counter['attempted']} "
226+
f"succeeded={write_counter['succeeded']} "
227+
f"failed={write_counter['failed']} "
228+
f"qsize={metrics_queue.qsize()}"
229+
)
230+
writer_stop.set()
231+
writer_thread.join(timeout=10)
232+
logging_utility.info(
233+
f"[{job_id}] Metrics writer stopped — "
234+
f"final attempted={write_counter['attempted']} "
235+
f"succeeded={write_counter['succeeded']} "
236+
f"failed={write_counter['failed']}"
237+
)
159238
# ─────────────────────────────────────────────────────────────────────
160239

161240
process.wait()

0 commit comments

Comments
 (0)