Skip to content

Commit 5d75842

Browse files
K3 smoke: stop misreporting the DFlash drafter (spec-decode-only, not a transformers model)
The DFlash drafter (z-lab/gemma-4-26B-A4B-it-DFlash) declares architectures=['DFlashDraftModel'] but ships no modeling file and no auto_map, and DFlashDraftModel is not a built-in transformers class. AutoModelForCausalLM therefore silently fell back to the base model_type=qwen3, dropping the DFlash weights (fc/hidden_norm) and newly initialising lm_head/embed_tokens. The smoke then ran a standalone forward and reported drafter_forward_ok=true. That signal was meaningless: the block-diffusion drafting protocol was never exercised. Per the model card, DFlash runs only via vLLM (PR #41703) or SGLang speculative decoding.

Fix:

* _detect_drafter_loadability(): flags spec-decode-only drafters (dflash_config / DFlashDraftModel arch not importable, no auto_map).
* _load_drafter(): for such drafters, load the qwen3 backbone ONLY as a labeled memory probe (kind=dflash_backbone_memory_probe, faithful=False).
* main(): SKIP the standalone drafter forward for spec-decode-only drafters (stage drafter_forward_skipped + validation_path), instead of running garbage through a misloaded backbone.
* summary: drafter_forward_ok=null (n/a) + drafter_faithful_transformers_load + drafter_note + drafter_validation_path, instead of a false true.

Faithful DFlash speedup validation is deferred to the vLLM/SGLang path (Block A part 1).

Co-authored-by: FluffyAIcode <FluffyAIcode@users.noreply.github.com>
1 parent aae96aa commit 5d75842

1 file changed

Lines changed: 102 additions & 9 deletions

File tree

scripts/research/k3_feasibility_smoke.py

Lines changed: 102 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -267,17 +267,78 @@ def _load_verifier_mac(verifier_path: str) -> Dict[str, Any]:
267267
}
268268

269269

270+
def _detect_drafter_loadability(drafter_id: str) -> Dict[str, Any]:
    """Decide whether ``drafter_id`` is a faithful standalone transformers
    model or a spec-decode-only drafter (e.g. DFlash).

    DFlash drafters declare ``architectures=['DFlashDraftModel']`` (and/or
    carry a ``dflash_config`` block) but ship **no modeling file and no
    ``auto_map``**, and ``DFlashDraftModel`` is not a built-in transformers
    class. So ``AutoModelForCausalLM`` silently falls back to the base
    ``model_type`` (qwen3), dropping the DFlash-specific weights
    (``fc``/``hidden_norm``) and newly-initialising ``lm_head``/
    ``embed_tokens``. The result runs a forward but is NOT the DFlash
    drafting protocol — DFlash is only runnable via vLLM (PR #41703) or
    SGLang speculative decoding per its model card. This detector lets the
    smoke report that honestly instead of emitting a misleading
    ``drafter_forward_ok=true``.

    A drafter is flagged spec-decode-only only when ALL of these hold:

    * it carries a DFlash marker (``dflash_config`` key, or an architecture
      name containing ``dflash``);
    * none of its declared architectures is an attribute of the installed
      ``transformers`` package;
    * its ``auto_map`` has no ``AutoModelForCausalLM`` entry, i.e. the repo
      ships no remote code that the auto class could load instead.

    Args:
        drafter_id: Hub repo id whose ``config.json`` is fetched with
            ``hf_hub_download``.

    Returns:
        A dict with ``specdecode_only`` (bool), ``architectures`` (list as
        declared in the config) and ``reason`` (explanatory string, or
        ``None`` when not spec-decode-only). When the config cannot be
        fetched or parsed the dict additionally carries ``detect_error``
        and ``specdecode_only`` is ``False`` so the caller just loads it.
    """
    import transformers
    from huggingface_hub import hf_hub_download

    try:
        cfg_path = hf_hub_download(drafter_id, "config.json")
        cfg = json.loads(Path(cfg_path).read_text(encoding="utf-8"))
        if not isinstance(cfg, dict):
            # Valid JSON but not an object: treat like any other unusable
            # config instead of crashing on ``cfg.get`` below.
            raise TypeError(
                f"config.json top level is {type(cfg).__name__}, not an object"
            )
    except Exception as e:  # network / gated / missing — fall back to "load it"
        return {"specdecode_only": False, "architectures": [],
                "reason": None, "detect_error": f"{type(e).__name__}: {e}"}

    archs = cfg.get("architectures", []) or []
    has_dflash_marker = ("dflash_config" in cfg) or any(
        "dflash" in str(a).lower() for a in archs
    )
    # ``hasattr`` requires a str attribute name; ignore malformed entries.
    arch_importable = any(
        isinstance(a, str) and hasattr(transformers, a) for a in archs
    )
    # The auto class resolves remote code by looking up its own class name in
    # ``auto_map``; with such an entry there is no silent qwen3 fallback, so
    # the drafter must not be reported as spec-decode-only.
    auto_map = cfg.get("auto_map")
    has_remote_code = (
        isinstance(auto_map, dict) and "AutoModelForCausalLM" in auto_map
    )
    specdecode_only = bool(
        has_dflash_marker and not arch_importable and not has_remote_code
    )
    reason = None
    if specdecode_only:
        reason = (
            f"architectures={archs} is not loadable as a standalone "
            f"transformers model (no auto_map / not a built-in class). "
            f"DFlash is a block-diffusion speculative-decoding drafter; run "
            f"it via vLLM (PR #41703) or SGLang per the model card. The "
            f"transformers path here only loads the qwen3 backbone as a "
            f"memory probe and does NOT exercise the DFlash drafting protocol."
        )
    return {"specdecode_only": specdecode_only, "architectures": archs,
            "reason": reason}
314+
315+
270316
def _load_drafter(drafter_id: str, platform: str) -> Dict[str, Any]:
271-
"""Load DFlash drafter. Always via transformers (drafter is small and
272-
PyTorch on both CUDA and MPS handles it without the bf16/MLX
273-
quantization decision."""
317+
"""Load the drafter for the feasibility smoke.
318+
319+
For a faithful standalone transformers drafter this loads it normally
320+
and a real forward is run downstream. For a spec-decode-only drafter
321+
(DFlash — see :func:`_detect_drafter_loadability`) the qwen3 backbone
322+
is still loaded so we can report its resident-memory footprint, but it
323+
is flagged ``specdecode_only`` / ``faithful=False`` so the caller skips
324+
the (meaningless) standalone forward and the report does not claim the
325+
DFlash protocol was exercised.
326+
"""
274327
import torch
275328
from transformers import AutoModelForCausalLM, AutoTokenizer
276329

330+
detect = _detect_drafter_loadability(drafter_id)
277331
print(
278332
f"[k3-smoke] loading drafter ({platform}): {drafter_id}",
279333
file=sys.stderr, flush=True,
280334
)
335+
if detect["specdecode_only"]:
336+
print(f"[k3-smoke] NOTE: {detect['reason']}", file=sys.stderr)
337+
print(
338+
"[k3-smoke] -> loading qwen3 backbone as a MEMORY PROBE ONLY "
339+
"(not a faithful DFlash load; standalone forward will be skipped).",
340+
file=sys.stderr,
341+
)
281342
t0 = time.perf_counter()
282343
tokenizer = AutoTokenizer.from_pretrained(drafter_id, trust_remote_code=True)
283344
if platform == "cuda":
@@ -301,14 +362,22 @@ def _load_drafter(drafter_id: str, platform: str) -> Dict[str, Any]:
301362
model.eval()
302363
elapsed = time.perf_counter() - t0
303364
print(
304-
f"[k3-smoke] drafter loaded in {elapsed:.1f}s",
365+
f"[k3-smoke] drafter loaded in {elapsed:.1f}s"
366+
+ (" (backbone memory probe)" if detect["specdecode_only"] else ""),
305367
file=sys.stderr,
306368
)
307369
return {
308-
"kind": f"transformers_{platform}",
370+
"kind": (
371+
"dflash_backbone_memory_probe" if detect["specdecode_only"]
372+
else f"transformers_{platform}"
373+
),
309374
"model": model,
310375
"tokenizer": tokenizer,
311376
"load_seconds": elapsed,
377+
"faithful": not detect["specdecode_only"],
378+
"specdecode_only": detect["specdecode_only"],
379+
"architectures": detect["architectures"],
380+
"note": detect["reason"],
312381
}
313382

314383

@@ -546,8 +615,23 @@ def main() -> int:
546615
_emit(report, args.output)
547616
return 30
548617

549-
# Drafter forward (if loaded).
550-
if drafter is not None:
618+
# Drafter forward (if loaded). For a spec-decode-only drafter (DFlash)
619+
# a standalone transformers forward is meaningless (the backbone was
620+
# loaded with newly-initialised embeddings), so skip it and record the
621+
# honest reason + the real validation path instead of running garbage.
622+
drafter_specdecode_only = bool(drafter and drafter.get("specdecode_only"))
623+
if drafter is not None and drafter_specdecode_only:
624+
report["stages"].append({
625+
"stage": "drafter_forward_skipped",
626+
"reason": drafter.get("note"),
627+
"validation_path": "vllm_pr_41703_or_sglang",
628+
})
629+
print(
630+
"[k3-smoke] drafter forward SKIPPED (spec-decode-only drafter; "
631+
"validate via vLLM PR #41703 / SGLang — not transformers).",
632+
file=sys.stderr,
633+
)
634+
elif drafter is not None:
551635
try:
552636
draft_metrics = _drafter_forward(
553637
drafter, ver_metrics.get("prompt_token_count"),
@@ -574,11 +658,20 @@ def main() -> int:
574658
"verifier_loadable": True,
575659
"verifier_forward_ok": True,
576660
"drafter_loadable": drafter is not None,
661+
"drafter_faithful_transformers_load": bool(drafter and drafter.get("faithful")),
662+
# None == "not applicable" for a spec-decode-only drafter; bool
663+
# otherwise. Avoids the previous misleading drafter_forward_ok=true
664+
# for a backbone that never ran the DFlash protocol.
577665
"drafter_forward_ok": (
578-
drafter is not None
579-
and report["stages"][-1].get("stage") == "drafter_forward"
666+
None if drafter_specdecode_only else (
667+
drafter is not None
668+
and report["stages"][-1].get("stage") == "drafter_forward"
669+
)
580670
),
581671
}
672+
if drafter_specdecode_only:
673+
report["summary"]["drafter_note"] = drafter.get("note")
674+
report["summary"]["drafter_validation_path"] = "vllm_pr_41703_or_sglang"
582675
_emit(report, args.output)
583676
print("[k3-smoke] PASS", file=sys.stderr)
584677
return 0

0 commit comments

Comments
 (0)