add GRPO fine-tuning notebook for JSON invoice extraction using Fireworks Training API #243
base: main
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| FIREWORKS_API_KEY=your_fireworks_key | ||
| OPENROUTER_API_KEY=your_openrouter_key | ||
| FIREWORKS_ACCOUNT_ID=your_account_id |
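A minimal sketch of a fail-fast check for the three keys listed in `.env.example`, which could run before any training call. The helper name `missing_keys` and the placeholder test are hypothetical and not part of this PR; the sketch assumes the values are already in the process environment (for example after `load_dotenv()`).

```python
import os

# Key names are taken from .env.example in this PR; the helper is hypothetical.
REQUIRED_KEYS = ("FIREWORKS_API_KEY", "OPENROUTER_API_KEY", "FIREWORKS_ACCOUNT_ID")
# Assumption: unedited placeholders keep the "your_..." form used in .env.example.
PLACEHOLDER_PREFIX = "your_"


def missing_keys(env=None):
    """Return the required keys that are unset, empty, or still placeholders."""
    env = os.environ if env is None else env
    missing = []
    for key in REQUIRED_KEYS:
        value = env.get(key, "").strip()
        if not value or value.startswith(PLACEHOLDER_PREFIX):
            missing.append(key)
    return missing
```

Calling `missing_keys()` with no argument inspects `os.environ`; passing a dict makes the check easy to exercise without touching real credentials.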
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,94 @@ | ||
| # GRPO Fine-tuning on Fireworks Training API | ||
|
|
||
| This project demonstrates how to fine-tune **Qwen3-8B** for structured JSON invoice extraction using GRPO (Group Relative Policy Optimization) via the Fireworks Training API. The training loop runs from a local notebook. The model trains on remote GPUs managed by Fireworks. | ||
|
|
||
| The fine-tuned model scores **82% schema-valid accuracy** on a held-out eval set, beating both the base Qwen3-8B (62%) and GPT-4.1 (58%) on the same task. | ||
|
|
||
|  | ||
|
|
||
| --- | ||
|
|
||
| ## Setup and installations | ||
|
|
||
| **Get API Keys**: | ||
| - [Fireworks AI](https://fireworks.ai) — needed for training and inference. Requires RLOR (training) access. Store it as `FIREWORKS_API_KEY` in a `.env` file. | ||
| - [OpenRouter](https://openrouter.ai) — needed for base model and GPT-4.1 eval. Store it as `OPENROUTER_API_KEY` in a `.env` file. | ||
|
|
||
| Refer to `.env.example` for the structure of the file. You will also need your Fireworks account ID stored as `FIREWORKS_ACCOUNT_ID`. | ||
|
|
||
| **Clone the Fireworks cookbook**: | ||
| ```bash | ||
| git clone https://github.com/fw-ai/cookbook.git | ||
| ``` | ||
|
|
||
| **Install Dependencies**: | ||
|
|
||
| Ensure you have Python 3.10 or later installed. | ||
|
|
||
| ```bash | ||
| uv venv | ||
| source .venv/bin/activate | ||
| uv pip install python-dotenv jsonschema openai fireworks-ai matplotlib | ||
| uv pip install -e "cookbook/training[training]" | ||
| uv pip install eval-protocol nest_asyncio | ||
| ``` | ||
|
|
||
| Select the virtual environment as the kernel in the notebook. | ||
|
|
||
| **Run the notebook**: | ||
|
|
||
| Open and run `grpo_json_extraction.ipynb` end-to-end. The notebook covers: | ||
|
|
||
| 1. Reward function that scores JSON completions against a schema | ||
| 2. Dataset upload to Fireworks | ||
| 3. GRPO training loop against remote GPUs | ||
| 4. Baseline eval on base Qwen3-8B | ||
| 5. Post-training eval on the fine-tuned model | ||
| 6. GPT-4.1 comparison eval | ||
| 7. Inference on the deployed model | ||
|
|
||
| --- | ||
|
|
||
| ## Agent Skill | ||
|
|
||
| The `agent-skill/grpo-finetune/` folder contains a reusable agent skill that wraps the full GRPO fine-tuning pipeline — from reward validation to dataset upload to training to inference — into a single runnable script. | ||
|
|
||
| **What's included**: | ||
|
|
||
| - `run_pipeline.py` — end-to-end pipeline: validates reward, uploads dataset, runs GRPO training, evals the fine-tuned model, and runs sample inference | ||
| - `generate_reward.py` — validates that your `reward.py` satisfies the scoring contract before any GPU spend | ||
| - `agent_demo.py` — runs the deployed fine-tuned model on sample invoices and prints structured extraction results | ||
| - `SKILL.md` — skill definition for Claude Code; describes when and how to trigger the skill | ||
|
|
||
| **Run the pipeline**: | ||
|
|
||
| ```bash | ||
| python agent-skill/grpo-finetune/run_pipeline.py \ | ||
| --train ./train_prompts.jsonl \ | ||
| --eval ./eval_prompts.jsonl \ | ||
| --task invoice-extraction \ | ||
| --output-id <your-model-id> | ||
| ``` | ||
|
|
||
| **Run inference only** (if you already have a deployed model): | ||
|
|
||
| ```bash | ||
| python agent-skill/grpo-finetune/agent_demo.py invoices.txt \ | ||
| --deployment accounts/<account-id>/deployments/<your-model-id> | ||
| ``` | ||
|
|
||
| Sample invoices for testing are in `invoices.txt`. Replace them with your own data. | ||
|
|
||
| --- | ||
|
|
||
| ## 📬 Stay Updated with Our Newsletter! | ||
|
|
||
| **Get a FREE Data Science eBook** 📖 with 150+ essential lessons in Data Science when you subscribe to our newsletter! Stay in the loop with the latest tutorials, insights, and exclusive resources. [Subscribe now!](https://join.dailydoseofds.com) | ||
|
|
||
| [](https://join.dailydoseofds.com) | ||
|
|
||
| --- | ||
|
|
||
| ## Contribution | ||
|
|
||
| Contributions are welcome! Please fork the repository and submit a pull request with your improvements. | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
| --- | ||
| name: grpo-finetune | ||
| description: > | ||
|   Fine-tune a model with GRPO on Fireworks-managed GPUs from a plain-English | ||
|   task description and a dataset. Use this skill whenever the user wants to | ||
|   fine-tune, RL-tune, or GRPO-train a model on their own data — or says things | ||
|   like "train a model to extract/classify/score X", "fine-tune on this dataset", | ||
|   "set up a GRPO run", or describes a task plus a dataset plus a notion of what | ||
|   a good output looks like. Trigger even when the user does not name GRPO or | ||
|   Fireworks explicitly. | ||
| --- | ||
|
|
||
| # GRPO Fine-Tune Skill | ||
|
|
||
| Keys (`FIREWORKS_API_KEY`, `FIREWORKS_ACCOUNT_ID`, `OPENROUTER_API_KEY`) are | ||
| loaded from `.env` in the current directory. No extra setup needed if the | ||
| notebook already ran. | ||
|
|
||
| ## What you do when this skill triggers | ||
|
|
||
| ### 1. Understand the task | ||
| Read the user's description. Sample 3-5 rows from their dataset (`head` the | ||
| `.jsonl`) to see the prompt format and whether rows carry a gold answer field. | ||
|
|
||
| ### 2. Write reward.py | ||
|
|
||
| Use this exact reward — schema-only, same as the notebook. Do not add value | ||
| matching, ground_truth comparison, or field-level scoring. Do not modify it. | ||
|
|
||
| ```python | ||
| import json | ||
| from jsonschema import validate, ValidationError | ||
|
|
||
| SCHEMA = { | ||
|     "type": "object", | ||
|     "required": ["vendor", "date", "amount", "currency"], | ||
|     "properties": { | ||
|         "vendor": {"type": "string"}, | ||
|         "date": {"type": "string"}, | ||
|         "amount": {"type": "number"}, | ||
|         "currency": {"type": "string"}, | ||
|     }, | ||
|     "additionalProperties": False, | ||
| } | ||
|
|
||
| def score(completion: str, row=None) -> float: | ||
|     try: | ||
|         parsed = json.loads(completion.strip()) | ||
|     except (json.JSONDecodeError, ValueError): | ||
|         return 0.0 | ||
|     try: | ||
|         validate(instance=parsed, schema=SCHEMA) | ||
|         return 1.0 | ||
|     except ValidationError: | ||
|         return 0.5 | ||
|
|
||
| SELF_TESTS = [ | ||
|     ('{"vendor": "Acme", "date": "2024-01-15", "amount": 1250.0, "currency": "USD"}', None, 1.0), | ||
|     ('{"vendor": "Acme", "date": "2024-01-15"}', None, 0.5), | ||
|     ("not json", None, 0.0), | ||
| ] | ||
| ``` | ||
|
|
||
| The score contract is: 1.0 = valid JSON with correct schema, 0.5 = valid JSON | ||
| wrong shape, 0.0 = not JSON. This is the only reward logic needed. | ||
|
|
||
| ### 3. Show it and offer the edit | ||
| Show the user `reward.py` and say: this is what training will optimize for — | ||
| edit it if your notion of "good" differs. Wait for their go-ahead. | ||
|
|
||
| ### 4. Validate | ||
| ```bash | ||
| $PYTHON agent-skill/grpo-finetune/generate_reward.py --validate reward.py | ||
| ``` | ||
| Must print `PASS` before proceeding. | ||
|
|
||
| ### 5. Run the pipeline | ||
| ```bash | ||
| $PYTHON agent-skill/grpo-finetune/run_pipeline.py \ | ||
| --train <path-to-train.jsonl> \ | ||
| --eval <path-to-eval.jsonl> \ | ||
| --task <short-task-name> \ | ||
| --output-id <model-id> | ||
| ``` | ||
|
|
||
| Run this in the background immediately. Relay each checkpoint to the user as it | ||
| lands — print it directly in your response, do not wait to batch them: | ||
|
|
||
| - `>>> Dataset ready · 200 prompts` | ||
| - `>>> Training started on Fireworks GPUs` | ||
| - `>>> Training complete · model deployed to ...` | ||
| - `>>> Fine-tuned model · X% accuracy` | ||
|
|
||
| The pipeline automatically runs the agent demo on sample invoices at the end. | ||
|
|
||
|
Comment on lines +94 to +95 (Contributor):
Document the demo step as conditional.
|
||
| **Important:** training takes 30-60+ minutes. Use a timeout of at least 7200 | ||
| seconds. Do not use the default 10 minute timeout. | ||
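A minimal sketch of how a caller might apply the 7200-second timeout noted in `SKILL.md` when launching the pipeline from Python. The helper names `build_pipeline_cmd` and `run_with_timeout` are hypothetical; the script path and flags are the ones shown in this PR, and how an agent harness actually enforces timeouts is an assumption.

```python
import subprocess
import sys

PIPELINE_TIMEOUT_S = 7200  # from SKILL.md: "at least 7200 seconds"


def build_pipeline_cmd(train, eval_path, task, output_id, python=sys.executable):
    """Assemble the run_pipeline.py invocation shown in SKILL.md."""
    return [
        python, "agent-skill/grpo-finetune/run_pipeline.py",
        "--train", train,
        "--eval", eval_path,
        "--task", task,
        "--output-id", output_id,
    ]


def run_with_timeout(cmd, timeout_s=PIPELINE_TIMEOUT_S):
    """Run cmd and return its exit code, or None if it exceeded timeout_s."""
    try:
        # subprocess.run kills the child process when the timeout expires.
        return subprocess.run(cmd, timeout=timeout_s, check=False).returncode
    except subprocess.TimeoutExpired:
        return None
```

Returning `None` on timeout keeps "ran too long" distinct from a non-zero exit code, so the caller can report the two cases differently.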
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,141 @@ | ||
| """ | ||
| agent_demo.py | ||
| ------------- | ||
| Runs the fine-tuned invoice extraction model on sample invoices. | ||
|
|
||
| Usage: | ||
| python agent_demo.py invoices.txt | ||
| python agent_demo.py invoices.txt --deployment accounts/myaccount/deployments/invoice-extractor-v1 | ||
| """ | ||
|
|
||
| import argparse | ||
| import json | ||
| import os | ||
| import re | ||
| import sys | ||
| import time | ||
|
|
||
| from dotenv import load_dotenv | ||
| from openai import OpenAI | ||
|
|
||
| REQUIRED_FIELDS = {"vendor", "date", "amount", "currency"} | ||
|
|
||
| GREEN = "\033[32m" | ||
| RED = "\033[31m" | ||
| CYAN = "\033[36m" | ||
| GRAY = "\033[90m" | ||
| BOLD = "\033[1m" | ||
| RESET = "\033[0m" | ||
| ORANGE = "\033[38;5;214m" | ||
| YELLOW = "\033[38;5;226m" | ||
|
|
||
| DIVIDER = f"{GRAY}{'─' * 72}{RESET}" | ||
| DIVIDER_MID = f"{GRAY}{'╌' * 72}{RESET}" | ||
|
|
||
| BOX_INNER = 66 | ||
|
|
||
|
|
||
| def vlen(s): | ||
|     return len(re.sub(r'\033\[[0-9;]*m', '', s)) | ||
|
|
||
|
|
||
| def box_line(content): | ||
|     pad = BOX_INNER - vlen(content) | ||
|     return f" {ORANGE}║{RESET}{content}{' ' * pad}{ORANGE}║{RESET}" | ||
|
|
||
|
|
||
| def print_banner(deployment): | ||
|     lines = [ | ||
|         f" {ORANGE}╔{'═' * BOX_INNER}╗{RESET}", | ||
|         box_line(f" {BOLD}grpo-extract{RESET} {GRAY}·{RESET} {YELLOW}skill v1.0{RESET}"), | ||
|         box_line(f""), | ||
|         box_line(f" {GRAY}Model {RESET} {CYAN}Qwen3-8B fine-tuned via GRPO{RESET}"), | ||
|         box_line(f" {GRAY}Provider {RESET} {CYAN}Fireworks AI{RESET}"), | ||
|         box_line(f" {GRAY}Endpoint {RESET} {CYAN}{deployment}{RESET}"), | ||
|         box_line(f" {GRAY}Schema {RESET} {CYAN}vendor · date · amount · currency{RESET}"), | ||
|         f" {ORANGE}╚{'═' * BOX_INNER}╝{RESET}", | ||
|     ] | ||
|     print() | ||
|     for line in lines: | ||
|         print(line) | ||
|         time.sleep(0.06) | ||
|     print() | ||
|
|
||
|
|
||
| def extract(client, deployment, invoice_text): | ||
|     messages = [ | ||
|         {"role": "system", "content": "Extract the following fields from this invoice: vendor, date, amount, currency. /no-think"}, | ||
|         {"role": "user", "content": f"{invoice_text}\n\nReturn valid JSON only."}, | ||
|     ] | ||
|     resp = client.chat.completions.create( | ||
|         model=deployment, | ||
|         messages=messages, | ||
|         temperature=0.0, | ||
|         max_tokens=512, | ||
|     ) | ||
|     content = resp.choices[0].message.content | ||
|     if "</think>" in content: | ||
|         content = content.split("</think>")[-1].strip() | ||
|     content = re.sub(r"^```(?:json)?\s*|\s*```$", "", content.strip()) | ||
|     return json.loads(content) | ||
|
|
||
|
|
||
| def validate(result: dict) -> bool: | ||
|     return all(result.get(f) not in (None, "", 0) for f in REQUIRED_FIELDS) | ||
|
Comment on lines +83 to +84 (Contributor):
Make demo validation match the training/eval schema. This only checks for non-empty keys, so it can mark wrong types or extra fields as valid.
Suggested fix:
 import argparse
 import json
 import os
 import re
 import sys
 import time
 from dotenv import load_dotenv
+from jsonschema import ValidationError, validate as validate_json
 from openai import OpenAI
 REQUIRED_FIELDS = {"vendor", "date", "amount", "currency"}
+SCHEMA = {
+    "type": "object",
+    "required": ["vendor", "date", "amount", "currency"],
+    "properties": {
+        "vendor": {"type": "string"},
+        "date": {"type": "string"},
+        "amount": {"type": "number"},
+        "currency": {"type": "string"},
+    },
+    "additionalProperties": False,
+}
@@
 def validate(result: dict) -> bool:
-    return all(result.get(f) not in (None, "", 0) for f in REQUIRED_FIELDS)
+    try:
+        validate_json(instance=result, schema=SCHEMA)
+        return True
+    except ValidationError:
+        return False
|
||
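To make the gap raised in the review comment on `validate` concrete, here is a small dependency-free sketch comparing the non-empty check from `agent_demo.py` with a hand-rolled check that mirrors the rules in `SCHEMA`. The name `schema_like_check` and the `FIELD_TYPES` table are hypothetical stand-ins for the `jsonschema` call in the suggested fix, written this way only so the sketch runs with the standard library.

```python
REQUIRED_FIELDS = {"vendor", "date", "amount", "currency"}

# Hypothetical stand-in for SCHEMA's "properties" types (string / number).
FIELD_TYPES = {"vendor": str, "date": str, "amount": (int, float), "currency": str}


def non_empty_check(result):
    """The check currently in agent_demo.py."""
    return all(result.get(f) not in (None, "", 0) for f in REQUIRED_FIELDS)


def schema_like_check(result):
    """Mirror SCHEMA: exactly the four required keys, each with the right type."""
    if not isinstance(result, dict) or set(result) != REQUIRED_FIELDS:
        return False
    for field, expected in FIELD_TYPES.items():
        value = result[field]
        # JSON Schema "number" excludes booleans, but bool subclasses int in Python.
        if isinstance(value, bool) or not isinstance(value, expected):
            return False
    return True


base = {"vendor": "Acme", "date": "2024-01-15", "amount": 1250.0, "currency": "USD"}
string_amount = {**base, "amount": "1250"}   # wrong type for "amount"
extra_field = {**base, "notes": "paid"}      # violates additionalProperties: False
zero_amount = {**base, "amount": 0}          # schema-valid, but fails the non-empty check

for name, doc in [("string_amount", string_amount),
                  ("extra_field", extra_field),
                  ("zero_amount", zero_amount)]:
    print(name, non_empty_check(doc), schema_like_check(doc))
```

The two checks disagree on all three cases, which is the mismatch between the demo's pass rate and the training reward that the comment describes.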
|
|
||
|
|
||
| def run_agent(filepath: str, deployment: str): | ||
|     load_dotenv() | ||
|     client = OpenAI( | ||
|         api_key=os.environ["FIREWORKS_API_KEY"], | ||
|         base_url="https://api.fireworks.ai/inference/v1", | ||
|     ) | ||
|
|
||
|     print_banner(deployment) | ||
|     print(DIVIDER) | ||
|
|
||
|     with open(filepath) as f: | ||
|         raw_docs = [line.strip() for line in f if line.strip()] | ||
|
|
||
|     total = len(raw_docs) | ||
|     passed = 0 | ||
|
|
||
|     for i, doc in enumerate(raw_docs, 1): | ||
|         print(f"\n{GRAY}#{i} of {total}{RESET}") | ||
|         print(f"{CYAN}{doc}{RESET}") | ||
|         print(DIVIDER_MID) | ||
|
|
||
|         t0 = time.time() | ||
|         try: | ||
|             result = extract(client, deployment, doc) | ||
|             elapsed = round(time.time() - t0, 2) | ||
|             valid = validate(result) | ||
|             if valid: | ||
|                 passed += 1 | ||
|
|
||
|             for field in REQUIRED_FIELDS: | ||
|                 val = result.get(field, "—") | ||
|                 print(f" {GRAY}{field:<10}{RESET} {val}") | ||
|
|
||
|             print() | ||
|             if valid: | ||
|                 print(f" {GREEN}✓ Schema valid{RESET} {GRAY}{elapsed}s{RESET}") | ||
|             else: | ||
|                 print(f" {RED}✗ Schema mismatch{RESET} {GRAY}{elapsed}s{RESET}") | ||
|
|
||
|         except Exception as e: | ||
|             print(f" {RED}✗ Error: {e}{RESET}") | ||
|
|
||
|         print(DIVIDER) | ||
|
|
||
|     pct = round(passed / total * 100) | ||
|     print(f"\n {BOLD}Results{RESET} {GREEN}{passed}/{total} valid{RESET} Schema match: {GREEN}{pct}%{RESET}\n") | ||
|
Comment on lines +97 to +132 (Contributor):
Handle an empty invoices file. If the file is empty or only whitespace, `total` is 0 and `passed / total` raises `ZeroDivisionError`.
Tools: Ruff (0.15.15) [warning] 126-126: Do not catch blind exception: `Exception` (BLE001)
|
||
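One way the empty-input case from the review comment above could be guarded, sketched under the assumption that the summary line is factored into a helper. `load_docs` and `summarize` are hypothetical names, and the output format is simplified (no ANSI colors) compared with `run_agent`.

```python
def load_docs(lines):
    """Keep non-blank lines, stripped, as run_agent does when reading the file."""
    return [line.strip() for line in lines if line.strip()]


def summarize(passed, total):
    """Build the results line without dividing by zero on an empty run."""
    if total == 0:
        return "No invoices found in the input file."
    pct = round(passed / total * 100)
    return f"{passed}/{total} valid, schema match: {pct}%"
```

Checking `total` before the division keeps the demo from crashing after the banner has already printed; exiting early with a non-zero status would be an equally reasonable choice.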
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|     ap = argparse.ArgumentParser() | ||
|     ap.add_argument("invoices", help="path to invoices.txt") | ||
|     ap.add_argument("--deployment", default="accounts/<account-id>/deployments/invoice-extractor-v1", | ||
|                     help="deployed model endpoint") | ||
|     args = ap.parse_args() | ||
|     run_agent(args.invoices, args.deployment) | ||
Document the required notebook working directory.
These steps never say to start Jupyter from `grpo-finetuning-qwen3`, but the notebook resolves `./cookbook/training`, `./train_prompts.jsonl`, and `./eval_prompts.jsonl` with process-relative paths. Launching the notebook from the repo root will make the setup and data-loading cells fail immediately.
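A minimal sketch of a first notebook cell that could surface this problem early, assuming the three relative paths named in the comment are the ones that matter. `missing_paths` is a hypothetical helper, not something the notebook currently contains.

```python
from pathlib import Path

# Relative paths the review comment says the notebook resolves.
EXPECTED_PATHS = ("cookbook/training", "train_prompts.jsonl", "eval_prompts.jsonl")


def missing_paths(base=None):
    """Return the expected paths that do not exist under base (default: cwd)."""
    base = Path.cwd() if base is None else Path(base)
    return [p for p in EXPECTED_PATHS if not (base / p).exists()]
```

In a notebook, a non-empty result from `missing_paths()` could be turned into a clear error that names the directory Jupyter should be started from, instead of a later `FileNotFoundError` in the data-loading cells.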