Skip to content

Commit dddfc6a

Browse files
ChenhanYuclaude
andcommitted
add: synthetic dataset, regression test, Qwen3-0.6B example
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
1 parent f7bb9ad commit dddfc6a

File tree

9 files changed

+1447
-104
lines changed

9 files changed

+1447
-104
lines changed

examples/dataset/README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,16 @@ python -m modelopt.torch.utils.plugins.megatron_preprocess_data \
219219
--workers 32 \
220220
--reasoning_content inline
221221
```
222+
223+
## Synthetic Test Dataset
224+
225+
`synthetic_conversations_1k.jsonl` is a 1,000-sample dataset in OpenAI messages format
226+
(900 single-turn + 100 two-turn conversations) covering writing, reasoning, math, coding,
227+
STEM, extraction, humanities, and roleplay categories.
228+
229+
This dataset was synthesized by Claude (Anthropic) and is licensed under Apache-2.0.
230+
It is intended for testing and CI regression checks — not for production training.
231+
232+
```json
233+
{"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
234+
```

examples/dataset/synthetic_conversations_1k.jsonl

Lines changed: 1000 additions & 0 deletions
Large diffs are not rendered by default.

tests/examples/speculative_decoding/conftest.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,10 @@ def tiny_conversations_path(tmp_path_factory):
4646

4747
@pytest.fixture(scope="session", autouse=True)
4848
def tiny_daring_anteater_path():
49-
"""Return path to static test data (100 samples in OpenAI messages format).
49+
"""Return path to synthetic test data in OpenAI messages format.
5050
51-
test_data.jsonl contains 100 synthetic samples (80 single-turn, 20 two-turn)
52-
generated by Claude (Anthropic) for testing purposes. Licensed under Apache-2.0.
51+
Uses examples/dataset/synthetic_conversations_1k.jsonl (1000 samples,
52+
900 single-turn + 100 two-turn). Synthesized by Claude (Anthropic),
53+
Apache-2.0 licensed.
5354
"""
54-
return Path(__file__).parent / "test_data.jsonl"
55+
return Path(__file__).parents[3] / "examples" / "dataset" / "synthetic_conversations_1k.jsonl"

tests/examples/speculative_decoding/test_data.jsonl

Lines changed: 0 additions & 100 deletions
This file was deleted.
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""DFlash E2E regression tests.
17+
18+
Tests the full DFlash pipeline using Qwen3-0.6B and the synthetic dataset
19+
(examples/dataset/synthetic_conversations_1k.jsonl). Matches the configuration
20+
in tools/launcher/examples/Qwen/Qwen3-0.6B/hf_online_dflash.yaml.
21+
22+
Convergence baseline (from L40 run):
23+
Step 100 (epoch 0.2): loss=6.59 acc=0.079
24+
Step 500 (epoch 1.0): loss=1.78 acc=0.525
25+
Step 1500 (epoch 3.0): loss=1.11 acc=0.595
26+
"""
27+
28+
import json
29+
import os
30+
31+
import pytest
32+
from _test_utils.examples.run_command import MODELOPT_ROOT, run_example_command
33+
34+
DFLASH_YAML = str(
35+
MODELOPT_ROOT / "modelopt_recipes" / "general" / "speculative_decoding" / "dflash.yaml"
36+
)
37+
38+
CHAT_TEMPLATE = str(
39+
MODELOPT_ROOT
40+
/ "tools"
41+
/ "launcher"
42+
/ "examples"
43+
/ "Qwen"
44+
/ "Qwen3-0.6B"
45+
/ "chat_template_train.jinja"
46+
)
47+
48+
SYNTH_DATA = str(MODELOPT_ROOT / "examples" / "dataset" / "synthetic_conversations_1k.jsonl")
49+
50+
# Match tools/launcher/examples/Qwen/Qwen3-0.6B/hf_online_dflash.yaml
51+
_DFLASH_OVERRIDES = [
52+
f"data.data_path={SYNTH_DATA}",
53+
f"data.chat_template={CHAT_TEMPLATE}",
54+
"training.training_seq_len=512",
55+
"training.per_device_train_batch_size=2",
56+
"training.logging_steps=100",
57+
"training.answer_only_loss=true",
58+
"dflash.dflash_block_size=8",
59+
"dflash.dflash_mask_token_id=151669",
60+
"dflash.dflash_use_torch_compile=False",
61+
"dflash.dflash_architecture_config.num_hidden_layers=2",
62+
]
63+
64+
65+
@pytest.fixture(scope="session")
66+
def qwen3_model_name():
67+
"""Qwen3-0.6B model name (downloaded from HF on first use)."""
68+
return "Qwen/Qwen3-0.6B"
69+
70+
71+
@pytest.fixture(scope="session")
72+
def dflash_output_dir(tmp_path_factory):
73+
return tmp_path_factory.mktemp("dflash_output")
74+
75+
76+
def test_dflash_training(qwen3_model_name, dflash_output_dir):
77+
"""Train DFlash on Qwen3-0.6B and validate loss convergence."""
78+
output_dir = str(dflash_output_dir / "dflash-qwen3-0.6b")
79+
overrides = [
80+
f"model.model_name_or_path={qwen3_model_name}",
81+
f"training.output_dir={output_dir}",
82+
"training.num_train_epochs=3",
83+
"training.save_steps=500",
84+
*_DFLASH_OVERRIDES,
85+
]
86+
87+
run_example_command(
88+
["./launch_train.sh", "--config", DFLASH_YAML, *overrides],
89+
"speculative_decoding",
90+
)
91+
92+
# Verify checkpoint was saved
93+
assert os.path.exists(os.path.join(output_dir, "modelopt_state.pth")) or any(
94+
"checkpoint-" in d
95+
for d in os.listdir(output_dir)
96+
if os.path.isdir(os.path.join(output_dir, d))
97+
)
98+
99+
# Regression: verify loss decreased
100+
trainer_state = os.path.join(output_dir, "trainer_state.json")
101+
assert os.path.exists(trainer_state), "trainer_state.json not found"
102+
with open(trainer_state) as f:
103+
state = json.load(f)
104+
logs = [h for h in state.get("log_history", []) if "loss" in h]
105+
assert len(logs) >= 2, f"Expected at least 2 log entries, got {len(logs)}"
106+
107+
first_loss = float(logs[0]["loss"])
108+
final_loss = float(logs[-1]["loss"])
109+
assert final_loss < first_loss, f"Loss did not decrease: {first_loss:.3f} -> {final_loss:.3f}"
110+
# Sanity: final loss should be reasonable (baseline: ~1.1 on L40)
111+
assert final_loss < 3.0, f"Final loss {final_loss:.3f} too high (expected < 3.0)"
112+
113+
114+
def test_dflash_resume(qwen3_model_name, dflash_output_dir):
115+
"""Resume DFlash training from checkpoint."""
116+
output_dir = str(dflash_output_dir / "dflash-qwen3-0.6b")
117+
overrides = [
118+
f"model.model_name_or_path={qwen3_model_name}",
119+
f"training.output_dir={output_dir}",
120+
"training.num_train_epochs=4",
121+
"training.save_steps=5000",
122+
*_DFLASH_OVERRIDES,
123+
]
124+
125+
run_example_command(
126+
["./launch_train.sh", "--config", DFLASH_YAML, *overrides],
127+
"speculative_decoding",
128+
)
129+
130+
131+
def test_dflash_export(dflash_output_dir):
132+
"""Export DFlash checkpoint to deployment format."""
133+
output_dir = str(dflash_output_dir / "dflash-qwen3-0.6b")
134+
export_dir = str(dflash_output_dir / "dflash-export")
135+
136+
run_example_command(
137+
[
138+
"python",
139+
"./scripts/export_hf_checkpoint.py",
140+
"--model_path",
141+
output_dir,
142+
"--export_path",
143+
export_dir,
144+
],
145+
"speculative_decoding",
146+
)
147+
148+
assert os.path.exists(os.path.join(export_dir, "model.safetensors"))
149+
assert os.path.exists(os.path.join(export_dir, "config.json"))
150+
151+
with open(os.path.join(export_dir, "config.json")) as f:
152+
config = json.load(f)
153+
assert config["architectures"] == ["DFlashDraftModel"]
154+
assert config["model_type"] == "qwen3"
155+
assert "dflash_config" in config
156+
assert "block_size" in config
157+
158+
159+
def test_dflash_ar_validate(dflash_output_dir):
160+
"""AR validation on trained DFlash checkpoint."""
161+
output_dir = str(dflash_output_dir / "dflash-qwen3-0.6b")
162+
163+
run_example_command(
164+
[
165+
"python",
166+
"./scripts/ar_validate.py",
167+
"--model_path",
168+
output_dir,
169+
"--osl",
170+
"10",
171+
"--num_samples",
172+
"3",
173+
"--steps",
174+
"7",
175+
],
176+
"speculative_decoding",
177+
)
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Regression check for training jobs.
17+
18+
Reads trainer_state.json from a HuggingFace Trainer checkpoint and validates
19+
that final metrics meet specified thresholds. Used by training scripts to
20+
catch regressions in CI.
21+
22+
Environment variables (all optional — no check if unset):
23+
MAX_FINAL_LOSS: Final loss must be below this value
24+
MIN_FINAL_ACC: Final accuracy must be above this value (any key containing 'acc')
25+
MAX_FINAL_PERPLEXITY: Final perplexity must be below this value
26+
27+
Usage:
28+
python check_regression.py /path/to/output_dir
29+
30+
Or from a shell script:
31+
python common/check_regression.py ${OUTPUT_DIR}
32+
33+
Exit codes:
34+
0 — all checks pass (or no thresholds set)
35+
1 — regression detected
36+
"""
37+
38+
import json
39+
import os
40+
import sys
41+
from glob import glob
42+
43+
44+
def find_trainer_state(output_dir):
45+
"""Find the latest trainer_state.json in the output directory."""
46+
# Check checkpoint subdirs first (sorted by step number)
47+
checkpoint_states = sorted(glob(os.path.join(output_dir, "checkpoint-*", "trainer_state.json")))
48+
if checkpoint_states:
49+
return checkpoint_states[-1]
50+
# Fall back to output_dir itself
51+
direct = os.path.join(output_dir, "trainer_state.json")
52+
if os.path.exists(direct):
53+
return direct
54+
return None
55+
56+
57+
def get_final_metrics(trainer_state_path):
58+
"""Extract final loss and accuracy from trainer_state.json."""
59+
with open(trainer_state_path) as f:
60+
state = json.load(f)
61+
62+
logs = [h for h in state.get("log_history", []) if "loss" in h]
63+
if not logs:
64+
return {}
65+
66+
last = logs[-1]
67+
metrics = {"loss": float(last["loss"])}
68+
69+
# Find any accuracy key (train_acc/parallel_0_step_0, eval_accuracy, etc.)
70+
for key, value in last.items():
71+
if "acc" in key.lower():
72+
metrics["accuracy"] = float(value)
73+
break
74+
75+
# Perplexity if available
76+
if "perplexity" in last:
77+
metrics["perplexity"] = float(last["perplexity"])
78+
79+
return metrics
80+
81+
82+
def check_regression(metrics):
83+
"""Check metrics against environment variable thresholds. Returns (passed, messages)."""
84+
checks = [
85+
(
86+
"MAX_FINAL_LOSS",
87+
"loss",
88+
lambda val, thresh: val <= thresh,
89+
"loss {val:.3f} > threshold {thresh}",
90+
),
91+
(
92+
"MIN_FINAL_ACC",
93+
"accuracy",
94+
lambda val, thresh: val >= thresh,
95+
"acc {val:.3f} < threshold {thresh}",
96+
),
97+
(
98+
"MAX_FINAL_PERPLEXITY",
99+
"perplexity",
100+
lambda val, thresh: val <= thresh,
101+
"perplexity {val:.3f} > threshold {thresh}",
102+
),
103+
]
104+
105+
passed = True
106+
messages = []
107+
108+
for env_var, metric_key, check_fn, fail_msg in checks:
109+
thresh_str = os.environ.get(env_var)
110+
if thresh_str is None:
111+
continue
112+
thresh = float(thresh_str)
113+
val = metrics.get(metric_key)
114+
if val is None:
115+
messages.append(f"WARNING: {env_var} set but '{metric_key}' not found in metrics")
116+
continue
117+
if check_fn(val, thresh):
118+
messages.append(f"PASS: {metric_key}={val:.3f} (threshold: {env_var}={thresh})")
119+
else:
120+
messages.append(f"REGRESSION: {fail_msg.format(val=val, thresh=thresh)}")
121+
passed = False
122+
123+
return passed, messages
124+
125+
126+
def main():
127+
"""Entry point for regression check CLI."""
128+
if len(sys.argv) < 2:
129+
print("Usage: python check_regression.py <output_dir>")
130+
sys.exit(0)
131+
132+
output_dir = sys.argv[1]
133+
134+
# Skip if no thresholds set
135+
if not any(
136+
os.environ.get(v) for v in ["MAX_FINAL_LOSS", "MIN_FINAL_ACC", "MAX_FINAL_PERPLEXITY"]
137+
):
138+
return
139+
140+
trainer_state = find_trainer_state(output_dir)
141+
if not trainer_state:
142+
print(f"WARNING: No trainer_state.json found in {output_dir}, skipping regression check")
143+
return
144+
145+
print(f"=== Regression Check ({trainer_state}) ===")
146+
metrics = get_final_metrics(trainer_state)
147+
if not metrics:
148+
print("No training logs found in trainer_state.json")
149+
return
150+
151+
print(f"Final metrics: {metrics}")
152+
passed, messages = check_regression(metrics)
153+
for msg in messages:
154+
print(f" {msg}")
155+
156+
if not passed:
157+
sys.exit(1)
158+
print("Regression check PASSED")
159+
160+
161+
if __name__ == "__main__":
162+
main()

tools/launcher/common/specdec/dflash_online_training.sh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,9 @@ for arg in sys.argv[1:]:
142142
echo "No new checkpoints to export in ${OUTPUT_DIR}"
143143
fi
144144
fi
145+
146+
# Regression check (uses env vars MAX_FINAL_LOSS, MIN_FINAL_ACC, etc.)
147+
if [ -n "$OUTPUT_DIR" ]; then
148+
python3 common/check_regression.py "${OUTPUT_DIR}" || true
149+
fi
145150
fi

0 commit comments

Comments
 (0)