[codex] DeepSeek FP4 MTP decode safeguards and MLA hooks #779

Draft

josusanmartin wants to merge 1 commit into ROCm:main from josusanmartin:codex/dsr1-fp4-mtp-mi355x-submission


Conversation

@josusanmartin

ATOM PR Draft

Title

DeepSeek FP4 MTP decode safeguards and guarded small-batch MLA path

Summary

This PR contains DeepSeek R1 FP4 + MTP changes developed for the AMD MI355X finals environment. The patch focuses on correctness-preserving decode improvements plus guarded, default-off experiment surfaces:

  • Clamp MTP token rollback during preemption so speculative cleanup cannot corrupt prompt/output state.
  • Add a guarded small-batch MLA stage1+reduce path for q_len=4, qh32, gqa32, FP8 KV decode.
  • Make MLA split count configurable through an env var instead of hard-coding it.
  • Add an opt-in exact TP greedy argmax path that avoids gathering full vocabulary logits in all-greedy decode.
  • Add default-off MTP proposal/diagnostic envs for future exact-verifier experimentation.
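The rollback-clamp idea in the first bullet can be sketched as follows. This is a hypothetical illustration of the invariant (never roll speculative cleanup back past the committed prompt boundary); the function and parameter names are not taken from the actual `atom/model_engine/scheduler.py`:

```python
def clamped_rollback(num_tokens: int, num_prompt_tokens: int,
                     num_draft_tokens: int) -> int:
    """Roll back speculative (draft) tokens after preemption, but never
    past the committed prompt boundary.

    Hypothetical sketch: names and structure are illustrative, not the
    submitted scheduler code.
    """
    # Only tokens beyond the prompt are eligible for speculative rollback.
    eligible = max(num_tokens - num_prompt_tokens, 0)
    rollback = min(num_draft_tokens, eligible)
    return num_tokens - rollback
```

Without the clamp, a preemption that arrives mid-verification could subtract more tokens than were actually speculative, corrupting the prompt/output accounting.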

Motivation

DeepSeek R1 FP4 with MTP spends substantial time in decode attention/MoE and MTP verifier plumbing. The competition workload uses fixed 8192/1024 random prompts at CONC=4,32,128, so safe decode-path improvements are valuable, but GSM8K correctness must remain intact.

The submitted leaderboard runs using this stack passed GSM8K at all three concurrencies.

Implementation Notes

  • atom/model_engine/scheduler.py: clamp speculative rollback and keep sequence token counts consistent after preemption.
  • atom/model_ops/attention_mla.py: add guarded direct AITER mla_decode_stage1_asm_fwd + mla_reduce_v1 path and scratch cache.
  • atom/model_engine/model_runner.py: support optional skip_logits path for exact greedy argmax from hidden states.
  • atom/model_ops/embed_head.py: expose local LM-head projection for TP-local argmax.
  • atom/model_ops/rejection_sampler.py: add argmax-only verifier path for exact greedy speculative acceptance.
  • atom/spec_decode/eagle.py and atom/utils/forward_context.py: carry optional completion counts and speculative metadata.
  • atom/utils/envs.py: add guarded env vars, all default-off except existing behavior-compatible defaults.
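The TP-local argmax idea behind the `embed_head.py` and `rejection_sampler.py` notes can be illustrated with a shard-local argmax merge. A minimal NumPy sketch, standing in for the tensor-parallel ranks (the function name is hypothetical; the real path would merge `(value, index)` pairs over `torch.distributed` rather than a Python loop):

```python
import numpy as np

def merge_shard_argmax(shards):
    """Exact greedy argmax across tensor-parallel vocab shards without
    gathering the full logits.

    Each rank contributes only its local (max value, global index) pair;
    merging those pairs recovers the exact global argmax. Hypothetical
    sketch of the technique, not the submitted implementation.
    """
    best_val, best_idx = -np.inf, -1
    offset = 0
    for shard in shards:           # one iteration per TP rank's vocab slice
        i = int(np.argmax(shard))  # local argmax only
        if shard[i] > best_val:
            best_val, best_idx = shard[i], offset + i
        offset += len(shard)
    return best_idx
```

Merging per-rank pairs communicates O(ranks) scalars per token instead of the full vocabulary, while remaining bit-exact with a full-gather argmax in all-greedy decode.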

Rule / Correctness Notes

  • MTP remains enabled.
  • The submitted runs keep tail-cheap and blind/no-argmax behavior disabled.
  • The top-k draft selector is exact-verifier-compatible and default-off.
  • The TP greedy no-gather path is default-off in the submitted configuration.
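Default-off gating of the kind described above usually comes down to a boolean env-var parse. A minimal sketch of that pattern (the helper name and the env var used in the test are hypothetical, not the actual contents of `atom/utils/envs.py`):

```python
import os

def env_flag(name: str, default: bool = False) -> bool:
    """Parse a boolean feature-gate env var, defaulting to off.

    Hypothetical sketch of the default-off pattern; the real envs.py
    names and parsing rules may differ.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in ("1", "true", "yes", "on")
```

Keeping every new surface behind such a flag means the submitted configuration is behavior-identical to the baseline unless a gate is explicitly set.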

Validation

Submitted leaderboard results:

| CONC | TP | GSM8K  | Tok/s/GPU | Interactivity | Median E2E |
| ---- | -- | ------ | --------- | ------------- | ---------- |
| 4    | 4  | 0.9378 | 1233.04   | 138.94        | 7.876 s    |
| 32   | 4  | 0.9363 | 3121.65   | 43.41         | 24.719 s   |
| 128  | 8  | 0.9378 | 3613.33   | 24.48         | 42.890 s   |

Known limitation: these submitted numbers are accuracy-safe (GSM8K passes at all three concurrencies) but do not clear all DeepSeek performance gates.

