
Commit 355c6b7

ChenhanYu and claude authored
fix: PTQ 1GPU, export PP divisibility, hidden states conversations key (#1293)
## Summary

- **megatron_lm_ptq.yaml**: Qwen3-8B PTQ on a single GPU for L40 clusters (TP=1, all tasks)
- **quantize.sh**: Auto-find the largest PP dividing the model's `num_hidden_layers` for the export step. Qwen3-8B has 36 layers, which isn't divisible by 8, causing an `AssertionError` on 8-GPU nodes
- **compute_hidden_states_trtllm.py**: Use `messages` with a `conversations` fallback, matching the HF version. Fixes `KeyError: 'conversations'` when data uses the OpenAI `messages` format

## Test plan

- [x] Qwen3-8B PTQ runs on a single L40 GPU
- [x] Export PP auto-selects a valid divisor (36 layers → PP=6 on 8 GPUs, PP=4 on 4 GPUs, PP=1 on 1 GPU)
- [x] EAGLE3 offline pipeline reads data with the `messages` field

🤖 Generated with [Claude Code](https://claude.com/claude-code)

## Summary by CodeRabbit

* **New Features**
  * Dataset input handling now supports multiple field formats for enhanced compatibility.
* **Bug Fixes**
  * Optimized GPU resource allocation during model quantization with improved pipeline parallelism computation.
  * Updated quantization configuration for more efficient resource utilization.

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 289a239 commit 355c6b7

3 files changed

Lines changed: 23 additions & 12 deletions

File tree

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -256,7 +256,7 @@ async def submit_generates():
     for entry in dataset:
         conversation_id = entry.get("conversation_id", entry.get("uuid"))

-        conversations = entry["conversations"]
+        conversations = entry.get("messages") or entry.get("conversations")
         if not conversations or not isinstance(conversations, list):
             num_invalid += 1
             continue
```
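The one-line fix above is easy to check in isolation. A minimal sketch of the lookup order (the entry dicts are illustrative, not taken from a real dataset); note that `or` also routes an empty `messages` list to the fallback, which is then caught by the `not conversations` validity check:

```python
def get_conversations(entry: dict):
    """Prefer OpenAI-style 'messages', fall back to ShareGPT-style 'conversations'."""
    return entry.get("messages") or entry.get("conversations")

# OpenAI messages format
openai_entry = {"messages": [{"role": "user", "content": "hi"}]}
# ShareGPT-style conversations format
sharegpt_entry = {"conversations": [{"from": "human", "value": "hi"}]}
# Neither key present: returns None, counted as invalid upstream
bad_entry = {"uuid": "123"}

print(get_conversations(openai_entry))    # [{'role': 'user', 'content': 'hi'}]
print(get_conversations(sharegpt_entry))  # [{'from': 'human', 'value': 'hi'}]
print(get_conversations(bad_entry))       # None
```

The previous `entry["conversations"]` raised `KeyError` on the first example; the `.get()` chain handles all three.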

tools/launcher/common/megatron_lm/quantize/quantize.sh

Lines changed: 14 additions & 3 deletions
```diff
@@ -41,11 +41,22 @@ TP=${TP:-1} PP=${PP:-1} EP=${EP:-1} ETP=${ETP:-1} ${QUANTIZE_EXE} ${MLM_MODEL_CF
 export MLM_EXTRA_ARGS="--mmlu-dataset ${MMLU_DATASET:-/hf-local/cais/mmlu} --fraction 0.01 --lower-bound ${MMLU_LOWER_BOUND:-0.38} --disable-tqdm"
 TP=${TP:-1} PP=${PP:-1} EP=${EP:-1} ETP=${ETP:-1} MLM_MODEL_CKPT=${MLM_MODEL_SAVE} ${MMLU_EXE} ${MLM_MODEL_CFG}

-# Export quantized checkpoint to HF format (PP=all GPUs)
+# Export quantized checkpoint to HF format
+# Use largest PP <= total GPUs that divides the model's num_hidden_layers
 TOTAL_GPUS=$(python3 -c "import torch; print(torch.cuda.device_count())" 2>/dev/null || echo ${NUM_GPUS:-1})
-echo "=== Exporting ${MLM_MODEL_CFG} ${QUANT_CFG} (PP=${TOTAL_GPUS}) ==="
+EXPORT_PP=$(python3 -c "
+import json, os
+cfg = os.path.join('${HF_MODEL_CKPT}', 'config.json')
+n_layers = json.load(open(cfg)).get('num_hidden_layers', 1) if os.path.exists(cfg) else 1
+gpus = ${TOTAL_GPUS}
+pp = gpus
+while pp > 1 and n_layers % pp != 0:
+    pp -= 1
+print(pp)
+" 2>/dev/null || echo ${TOTAL_GPUS})
+echo "=== Exporting ${MLM_MODEL_CFG} ${QUANT_CFG} (PP=${EXPORT_PP}, ${TOTAL_GPUS} GPUs) ==="
 export MLM_EXTRA_ARGS=
-TP=1 PP=${TOTAL_GPUS} EP=1 ETP=1 MLM_MODEL_CKPT=${MLM_MODEL_SAVE} ${EXPORT_EXE} ${MLM_MODEL_CFG}
+TP=1 PP=${EXPORT_PP} EP=1 ETP=1 MLM_MODEL_CKPT=${MLM_MODEL_SAVE} ${EXPORT_EXE} ${MLM_MODEL_CFG}
 ls ${EXPORT_DIR}
 cat ${EXPORT_DIR}/hf_quant_config.json
```
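The inline `python3 -c` snippet added above can be reproduced as a standalone function. This sketch mirrors its loop and reproduces the divisor values listed in the test plan (36 layers on 8, 4, and 1 GPUs):

```python
def largest_valid_pp(n_layers: int, gpus: int) -> int:
    """Largest PP <= gpus that evenly divides the layer count.

    Mirrors the inline python3 snippet in quantize.sh: start from the
    full GPU count and count down until a divisor is found. PP=1 always
    divides n_layers, so the loop always terminates with a valid value.
    """
    pp = gpus
    while pp > 1 and n_layers % pp != 0:
        pp -= 1
    return pp

# Qwen3-8B has 36 hidden layers; 36 is not divisible by 8
print(largest_valid_pp(36, 8))  # 6
print(largest_valid_pp(36, 4))  # 4
print(largest_valid_pp(36, 1))  # 1
```

With the old `PP=${TOTAL_GPUS}` behavior, the 8-GPU case attempted PP=8 against 36 layers and hit the divisibility `AssertionError`; the search above lands on PP=6 instead.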

tools/launcher/examples/Qwen/Qwen3-8B/megatron_lm_ptq.yaml

Lines changed: 8 additions & 8 deletions
```diff
@@ -24,7 +24,7 @@ pipeline:
     config:
       model: Qwen/Qwen3-8B
       quant_cfg: NVFP4_DEFAULT_CFG
-      tp: 8
+      tp: 1
       calib_dataset: abisee/cnn_dailymail
       calib_size: 32
       mmlu_dataset: cais/mmlu
@@ -33,15 +33,15 @@ pipeline:
     slurm_config:
       _factory_: "slurm_factory"
       nodes: 1
-      ntasks_per_node: 8
-      gpus_per_node: 8
+      ntasks_per_node: 1
+      gpus_per_node: 1

   task_1:
     _target_: common.megatron_lm.quantize.task.MegatronLMQuantizeTask
     config:
       model: Qwen/Qwen3-8B
       quant_cfg: FP8_DEFAULT_CFG
-      tp: 8
+      tp: 1
       calib_dataset: abisee/cnn_dailymail
       calib_size: 32
       mmlu_dataset: cais/mmlu
@@ -50,18 +50,18 @@ pipeline:
     slurm_config:
      _factory_: "slurm_factory"
       nodes: 1
-      ntasks_per_node: 8
-      gpus_per_node: 8
+      ntasks_per_node: 1
+      gpus_per_node: 1

   # Step 3: TRT-LLM eval MMLU on all exported checkpoints
   task_2:
     script: common/tensorrt_llm/eval.sh
     environment:
       - HF_MODEL_CKPT: /scratchspace/export
-      - TP: "8"
+      - TP: "1"
       - EP: "1"
     slurm_config:
       _factory_: "slurm_factory"
       nodes: 1
       ntasks_per_node: 1
-      gpus_per_node: 8
+      gpus_per_node: 1
```
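The YAML change switches every task from 8-way to single-GPU execution, so `tp`, `ntasks_per_node`, and `gpus_per_node` must stay mutually consistent. A hedged sanity-check sketch over dicts mirroring the per-task resource fields (the check itself is hypothetical, not part of the launcher):

```python
# Dicts mirroring the resource fields of the three tasks after this commit
tasks = {
    "task_0": {"tp": 1, "ntasks_per_node": 1, "gpus_per_node": 1},
    "task_1": {"tp": 1, "ntasks_per_node": 1, "gpus_per_node": 1},
    "task_2": {"tp": 1, "ntasks_per_node": 1, "gpus_per_node": 1},
}

for name, t in tasks.items():
    # TP cannot exceed the GPUs allocated to the task on its node
    assert t["tp"] <= t["gpus_per_node"], name
    # one rank per GPU for the Megatron tasks
    assert t["ntasks_per_node"] <= t["gpus_per_node"], name

print("single-GPU resource settings consistent")  # single-GPU resource settings consistent
```

On an L40 node this keeps each PTQ task on one GPU; raising `tp` without also raising `gpus_per_node` would trip the first assertion.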
