Commit bb19cea

Author: Francisco
feat(training): add live progress feedback to fine-tuning pipeline
- Add ProgressEmitter callback to unsloth_train.py: emits structured PROGRESS:{...} lines to stdout on every logging step, containing step, total_steps, epoch, loss, and learning_rate.
- Update the worker.py stdout loop to parse PROGRESS: lines and write metrics to job.metrics, with db.commit() on each step. Users polling client.training.retrieve(job_id) now get live feedback instead of no visibility between dispatch and completion.
- Fix HF_HUB_OFFLINE comment: clarified that the current value 0 permits downloads; the sovereignty guard comment was contradicting the code.
- Update CI workflow: add Rust toolchain install, cargo cache, and maturin develop --release build step to the test job so the fc_parser extension is available during unit tests on both Python 3.11 and 3.12.
1 parent 682dcf2 commit bb19cea

3 files changed

Lines changed: 79 additions & 21 deletions

.github/workflows/ci.yml

Lines changed: 19 additions & 15 deletions
@@ -123,6 +123,25 @@ jobs:
           pip install --require-hashes -r sandbox_reqs_hashed.txt
           pip install pytest pytest-cov

+      - name: "🦀 Install Rust toolchain"
+        uses: dtolnay/rust-toolchain@stable
+
+      - name: "📦 Cache Rust build artifacts"
+        uses: actions/cache@v4
+        with:
+          path: |
+            ~/.cargo/registry
+            ~/.cargo/git
+            rust/fc_parser/target
+          key: ${{ runner.os }}-rust-fc-parser-${{ matrix.python-version }}-${{ hashFiles('rust/fc_parser/Cargo.toml', 'rust/fc_parser/src/**') }}
+          restore-keys: |
+            ${{ runner.os }}-rust-fc-parser-${{ matrix.python-version }}-
+
+      - name: "🦀 Build fc_parser Rust extension"
+        run: |
+          pip install maturin
+          cd rust/fc_parser && maturin develop --release
+
       - name: "✅ Run Pytest with Coverage"
         run: pytest tests/ --cov=src --cov-report=term-missing
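Review note: the build step above exists so that the unit tests can import the Rust extension. A minimal sketch of a pre-test check follows, assuming the importable module name matches the crate directory (`fc_parser`); the helper name `extension_available` is hypothetical and not part of this commit. maturin's documentation describes `develop` as installing into the current virtualenv, so the step also assumes the job runs inside one.

```python
import importlib.util


def extension_available(module_name: str = "fc_parser") -> bool:
    """Return True if a top-level module can be located, without importing it."""
    # find_spec returns None when no finder can locate the module.
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    # "fc_parser" as the import name is an assumption based on the crate path.
    status = "available" if extension_available() else "missing"
    print(f"fc_parser extension: {status}")
```

Running a check like this before pytest would make a missing extension fail fast with a clear message rather than surfacing as import errors scattered across test modules.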

@@ -193,12 +212,6 @@ jobs:

   # ──────────────────────────────────────────────────────────────────────────
   # 🧪 Build & Publish Staging Images (DEV BRANCH ONLY)
-  #
-  # Pushes images tagged :dev-<sha> to Docker Hub.
-  # These are never tagged :latest and never affect production deployments.
-  # After a successful build, automatically updates docker-compose.yml in
-  # projectdavid-platform-staging so the staging environment always tracks
-  # the latest dev images without manual intervention.
   # ──────────────────────────────────────────────────────────────────────────
   build_staging:
     name: "🧪 Build & Publish Staging Images"
@@ -524,11 +537,6 @@ jobs:
           echo "| inference-worker | \`${{ steps.get_version.outputs.VERSION }}\` |" >> $GITHUB_STEP_SUMMARY
           echo "| router | \`${{ steps.get_version.outputs.VERSION }}\` |" >> $GITHUB_STEP_SUMMARY

-      # ── Update projectdavid-platform with new production image tags ─────────
-      # Updates ALL docker-compose yml files in both the root and
-      # projectdavid_platform/ directory — two instances of each file.
-      # Requires PLATFORM_REPO_PAT secret with write access to
-      # project-david-ai/projectdavid-platform.
       - name: "📦 Update projectdavid-platform with production image tags"
         env:
           PLATFORM_REPO_PAT: ${{ secrets.PLATFORM_REPO_PAT }}
@@ -540,7 +548,6 @@ jobs:
           git config user.email "ci@projectdavid.ai"
           git config user.name "Project David CI"

-          # All yml files carrying image tags — root level and package level.
           YML_FILES=(
             "docker-compose.yml"
             "docker-compose.gpu.yml"
@@ -559,14 +566,11 @@ jobs:
               echo "Skipping missing file: $FILE"
               continue
             fi
-            # Replace pinned semver tags: thanosprime/projectdavid-core-*:X.Y.Z
             sed -i "s|thanosprime/projectdavid-core-\([^:]*\):[0-9]*\.[0-9]*\.[0-9]*|thanosprime/projectdavid-core-\1:${VERSION}|g" "$FILE"
-            # Replace :latest tags if present
             sed -i "s|thanosprime/projectdavid-core-\([^:]*\):latest|thanosprime/projectdavid-core-\1:${VERSION}|g" "$FILE"
             echo "Updated: $FILE"
           done

-          # Update PINNED_IMAGES.md if it exists
           if [ -f "PINNED_IMAGES.md" ]; then
             sed -i "s|:[0-9]*\.[0-9]*\.[0-9]*|:${VERSION}|g" PINNED_IMAGES.md
             sed -i "s|:latest|:${VERSION}|g" PINNED_IMAGES.md

src/api/training/unsloth_train.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1+
import json
12
import os
23

34
# ─── SOVEREIGNTY GUARD ────────────────────────────────────────────────────────
4-
# Prevent any HuggingFace hub download attempts at runtime.
5-
# Only models already present in the local HF cache are permitted.
6-
# If the requested model is not cached, this will raise a clear error
7-
# rather than attempting a download — enforcing airgap compliance.
5+
# HF_HUB_OFFLINE = "1" enforces cache-only mode — no downloads permitted.
6+
# Set to "0" only if you explicitly want to allow HuggingFace hub downloads.
7+
# For production sovereign deployments this should be "1".
88
os.environ["HF_HUB_OFFLINE"] = "0"
99
# ──────────────────────────────────────────────────────────────────────────────
1010
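Review note: the new comment says production sovereign deployments should use "1", while the code still hard-codes "0". One way to reconcile the two, sketched here under the assumption that deployments can set environment variables for the training process, is to default to cache-only mode and let an explicit setting override it. The helper `hub_downloads_allowed` is hypothetical and only illustrates how the flag would be read; the truthy values follow the ones `huggingface_hub` documents for boolean environment variables.

```python
import os

# Default to cache-only mode. An HF_HUB_OFFLINE value already present in the
# environment is left untouched, so a dev setup can still export "0".
os.environ.setdefault("HF_HUB_OFFLINE", "1")


def hub_downloads_allowed() -> bool:
    """Return True only when offline mode is explicitly disabled."""
    value = os.environ.get("HF_HUB_OFFLINE", "1").strip().upper()
    return value not in {"1", "ON", "YES", "TRUE"}
```

As in the committed file, the variable needs to be set before the Hugging Face libraries are imported for it to take effect.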

@@ -15,6 +15,7 @@

 import unsloth  # noqa: F401 — must precede trl/transformers/peft
 from datasets import load_dataset
+from transformers import TrainerCallback
 from trl import SFTConfig, SFTTrainer
 from unsloth import FastLanguageModel, is_bfloat16_supported

@@ -37,6 +38,33 @@
 }


+# ─── PROGRESS EMITTER ─────────────────────────────────────────────────────────
+class ProgressEmitter(TrainerCallback):
+    """
+    Emits structured PROGRESS: lines to stdout on every logging step.
+    The training worker parses these lines and writes them to job.metrics
+    so users get live feedback during training instead of a black hole.
+
+    Output format (one line per logging step):
+        PROGRESS:{"step": 5, "total_steps": 20, "epoch": 0.25, "loss": 1.423, "learning_rate": 0.0002}
+    """
+
+    def on_log(self, args, state, control, logs=None, **kwargs):
+        if not logs:
+            return
+        progress = {
+            "step": state.global_step,
+            "total_steps": state.max_steps,
+            "epoch": round(state.epoch or 0, 3),
+            "loss": round(logs.get("loss", 0), 4),
+            "learning_rate": logs.get("learning_rate"),
+        }
+        print(f"PROGRESS:{json.dumps(progress)}", flush=True)
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", required=True)
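Review note: `on_log` fires for every log entry, including the end-of-training summary, which carries keys such as `train_runtime` and `train_loss` but no `loss`. With `logs.get("loss", 0)` that entry would be emitted as `"loss": 0`. The sketch below shows one possible guard, pulled out as a pure function so it can be exercised without a trainer; `build_progress` and `format_progress` are hypothetical names, not part of this commit.

```python
import json
from typing import Optional


def build_progress(step: int, total_steps: int, epoch: Optional[float],
                   logs: dict) -> Optional[dict]:
    """Build a PROGRESS payload, or None when the entry has no training loss."""
    if "loss" not in logs:
        # Summary or evaluation entries: skip instead of reporting loss 0.
        return None
    return {
        "step": step,
        "total_steps": total_steps,
        "epoch": round(epoch or 0, 3),
        "loss": round(logs["loss"], 4),
        "learning_rate": logs.get("learning_rate"),
    }


def format_progress(payload: dict) -> str:
    """Serialize a payload in the PROGRESS:{...} line format used above."""
    return f"PROGRESS:{json.dumps(payload)}"
```

Inside the callback this would be called as `build_progress(state.global_step, state.max_steps, state.epoch, logs)`, printing only when the result is not `None`.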
@@ -138,6 +166,7 @@ def format_prompts(examples):
         model=model,
         train_dataset=dataset,
         processing_class=tokenizer,
+        callbacks=[ProgressEmitter()],
         args=SFTConfig(
             dataset_text_field="text",
             per_device_train_batch_size=p["per_device_train_batch_size"],

src/api/training/worker.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
Responsibilities:
77
- Listen to the Redis training job queue
88
- Dispatch training jobs as direct subprocesses (GPU claimed for duration)
9+
- Parse PROGRESS: lines from unsloth_train.py and write live metrics
10+
to job.metrics so users get feedback during training
911
1012
Ray is NOT used here. inference_worker owns the Ray cluster (HEAD node)
1113
and manages all inference GPU reservations via Ray Serve.
@@ -54,6 +56,12 @@ def process_job(job_id: str, user_id: str):
     """
     Core training job logic — runs as a direct subprocess.

+    Progress feedback:
+        unsloth_train.py emits PROGRESS:{...} lines on every logging step.
+        This loop parses those lines and writes them to job.metrics so users
+        polling client.training.retrieve(job_id) get live loss and step count
+        rather than a black hole between dispatch and completion.
+
     All imports are local to keep the module lightweight and avoid
     import-time side effects from SQLAlchemy / ORM modules.
     """
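Review note: from the caller's side, the polling described in this docstring could look roughly like the sketch below. It takes the retrieval call as a parameter so it stays independent of the real SDK; the dict shape (`status`, `metrics`) and the terminal status names are assumptions for illustration, not the documented response of `client.training.retrieve`.

```python
import time
from typing import Callable, Iterator, Tuple


def watch_metrics(retrieve: Callable[[str], dict], job_id: str,
                  interval: float = 5.0,
                  terminal: Tuple[str, ...] = ("completed", "failed", "cancelled"),
                  ) -> Iterator[dict]:
    """Yield each new metrics snapshot until the job reaches a terminal status."""
    last_step = None
    while True:
        job = retrieve(job_id)
        metrics = job.get("metrics") or {}
        step = metrics.get("step")
        # Only surface a snapshot when the step counter has moved.
        if step is not None and step != last_step:
            last_step = step
            yield metrics
        if job.get("status") in terminal:
            return
        time.sleep(interval)
```

A caller would pass something like `lambda job_id: client.training.retrieve(job_id)`, adapted to whatever object that call actually returns.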
@@ -130,8 +138,26 @@ def _get_samba_client():
             cmd, stdout=_subprocess.PIPE, stderr=_subprocess.STDOUT, text=True
         )

+        # ─── STDOUT LOOP WITH PROGRESS PARSING ───────────────────────────────
+        # unsloth_train.py emits PROGRESS:{...} lines on every logging step.
+        # We parse these and write them to job.metrics so polling clients
+        # get live feedback. Non-PROGRESS lines are logged normally.
         for line in process.stdout:
-            logging_utility.info(f"[{job_id}] {line.strip()}")
+            line = line.strip()
+            logging_utility.info(f"[{job_id}] {line}")
+
+            if line.startswith("PROGRESS:"):
+                try:
+                    metrics = json.loads(line[9:])
+                    job.metrics = metrics
+                    job.updated_at = int(_time.time())
+                    db.commit()
+                except Exception as parse_err:
+                    logging_utility.warning(
+                        f"[{job_id}] Failed to parse PROGRESS line: {parse_err}"
+                    )
+        # ─────────────────────────────────────────────────────────────────────
+
         process.wait()

         if process.returncode == 0:
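Review note: the parsing branch could be factored into a small helper that derives the slice offset from the prefix instead of the literal `9` and catches only JSON errors, so a database failure in `db.commit()` is not reported as a parse failure. This is a sketch; `parse_progress_line` and `PROGRESS_PREFIX` are hypothetical names. The hunks shown for worker.py do not add a `json` import, so the committed loop assumes one already exists in the file.

```python
import json
from typing import Optional

PROGRESS_PREFIX = "PROGRESS:"


def parse_progress_line(line: str) -> Optional[dict]:
    """Return the metrics dict from a PROGRESS line, or None for any other line."""
    line = line.strip()
    if not line.startswith(PROGRESS_PREFIX):
        return None
    try:
        payload = json.loads(line[len(PROGRESS_PREFIX):])
    except json.JSONDecodeError:
        return None
    # Only a JSON object is treated as a valid metrics payload.
    return payload if isinstance(payload, dict) else None
```

The loop would then assign `job.metrics` and commit only when the helper returns a dict. Since the loop commits on every logging step, a short `logging_steps` interval means one database write per step, which may be worth throttling.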
@@ -142,7 +168,6 @@ def _get_samba_client():
                 name=f"FT: {job.base_model}",
                 base_model=job.base_model,
                 storage_path=model_rel_path,
-                # node_id removed — FK references compute_nodes (legacy, Phase 5 drop)
                 status=_StatusEnum.active,
                 created_at=int(_time.time()),
                 updated_at=int(_time.time()),
