Skip to content

Add hc_mult support to DFlash for DeepSeek-V4-Flash#524

Draft
rahul-tuli wants to merge 1 commit into
mainfrom
rtuli/dflash-dsv4-hc-mult
Draft

Add hc_mult support to DFlash for DeepSeek-V4-Flash#524
rahul-tuli wants to merge 1 commit into
mainfrom
rtuli/dflash-dsv4-hc-mult

Conversation

@rahul-tuli
Copy link
Copy Markdown
Collaborator

Purpose

Enable training DFlash speculators for DeepSeek-V4-Flash (DSv4), which uses Manifold-Constrained Hyper-Connection (mHC) with hc_mult=4. DSv4's hidden states are (N, hc_mult * hidden_size) per layer rather than the standard (N, hidden_size), and its checkpoint uses non-standard weight names (embed.weight, head.weight, norm.weight). Without these changes the DFlash pipeline cannot initialize or train a speculator for DSv4.

Description

  • scripts/train.py: Read hc_mult from the verifier config (default 1) and thread it into the draft model's transformer_layer_config. Pass hc_mult * hidden_size as the effective hidden size to both train and val dataloaders so create_collate_fn / create_empty_sample create correctly shaped tensors.
  • src/speculators/models/dflash/core.py:
    • __init__: Read self.hc_mult from config. Expand FC layer input dim to len(target_layer_ids) * hc_mult * hidden_size. Register hc_head_fn, hc_head_base, hc_head_scale buffers when hc_mult > 1.
    • forward: Apply hc_head_project() to collapse verifier_last_hidden_states from (N, hc_mult * hidden_size)(N, hidden_size) before verifier_norm in the loss path (only when hc_mult > 1).
    • load_verifier_weights: New override that handles DSv4's non-standard weight names and loads hc_head parameters. Delegates to super() when hc_mult == 1.
  • src/speculators/models/dflash/utils.py: hc_head_project() added in a prior commit (pure-PyTorch port of vLLM's _hc_head_fused_reference).

All changes are backward-compatible: when hc_mult=1 (all non-DSv4 models), every dimension calculation is algebraically identical to the prior code.

Related Issue

Part of the DFlash DSv4 code changes effort (Diff 1: Hidden States from the companion PRD).

Tests

  • All 264 existing unit tests pass (python -m pytest tests/unit/ -x -q).
  • Verified end-to-end with a smoke test script that initializes a DFlash speculator against the real DSv4 checkpoint, loads weights, and runs a dummy forward pass on CUDA — init, weight loading, and forward all pass.
--- __init__ checks ---
  hc_mult = 4
  FC: Linear(49152, 4096)
  hc_head_fn:    (4, 16384)
  hc_head_base:  (4,)
  hc_head_scale: (1,)
  PASS

--- load_verifier_weights checks ---
  embed_tokens:     (129280, 4096), loaded
  lm_head:          (32000, 4096), loaded
  verifier_lm_head: (32000, 4096), loaded
  verifier_norm:    (4096,), loaded
  hc_head_fn:       (4, 16384), loaded
  hc_head_base:     (4,), loaded
  hc_head_scale:    (1,), loaded
  PASS

--- forward pass checks ---
  hidden_states:                (1, 64, 49152)
  input_ids:                    (1, 64)
  verifier_last_hidden_states:  (1, 64, 16384)
  draft_tokens: (1, 32)
  loss:         29.4409
  metrics keys: ['full_acc', 'loss', 'position 1 acc', ...]
  PASS

I have filled in:

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan/results, such as providing test command and pasting the results.
  • (Optional) The necessary documentation update.
  • I (a human) have written or reviewed the code in this pr to the best of my ability.

Generalize the DFlash speculator model and training data pipeline to
handle verifier models with hc_mult > 1 (e.g. DSv4 where hc_mult=4).
All changes degenerate to current behavior when hc_mult=1.

- Read hc_mult from verifier config and thread it through to the
  draft model's transformer_layer_config
- Expand FC layer input dimension to len(target_layer_ids) * hc_mult *
  hidden_size
- Register hc_head_fn/base/scale buffers when hc_mult > 1
- Apply hc_head projection to verifier_last_hidden_states before
  verifier_norm in the forward pass loss computation path
- Override load_verifier_weights to handle DSv4's non-standard weight
  names (embed.weight, head.weight, norm.weight) and load hc_head
  parameters
- Pass effective_hidden_size (hc_mult * hidden_size) to dataloaders so
  empty sample creation and collation use correct tensor shapes

Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 15, 2026

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 2564a22d-75c0-441c-9288-f6c14048c1cb

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch rtuli/dflash-dsv4-hc-mult

Tip

💬 Introducing Slack Agent: The best way for teams to turn conversations into code.

Slack Agent is built on CodeRabbit's deep understanding of your code, so your team can collaborate across the entire SDLC without losing context.

  • Generate code and open pull requests
  • Plan features and break down work
  • Investigate incidents and troubleshoot customer tickets together
  • Automate recurring tasks and respond to alerts with triggers
  • Summarize progress and report instantly

Built for teams:

  • Shared memory across your entire org—no repeating context
  • Per-thread sandboxes to safely plan and execute work
  • Governance built-in—scoped access, auditability, and budget controls

One agent for your entire SDLC. Right inside Slack.

👉 Get started


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@mergify
Copy link
Copy Markdown

mergify Bot commented May 15, 2026

The quality checks have failed. Please run make style and make quality under
the root directory to address the lint failures. You will need to install the
dev optional install to get the required linting packages:
https://github.com/vllm-project/speculators/blob/main/CONTRIBUTING.md

my-other-github-account pushed a commit to my-other-github-account/speculators that referenced this pull request May 15, 2026
…t#414)

Updates the requirements on
[pytest-mock](https://github.com/pytest-dev/pytest-mock) to permit the
latest version.
<details>
<summary>Release notes</summary>
<p><em>Sourced from <a
href="https://github.com/pytest-dev/pytest-mock/releases">pytest-mock's
releases</a>.</em></p>
<blockquote>
<h2>v3.15.1</h2>
<p><em>2025-09-16</em></p>
<ul>
<li><a
href="https://redirect.github.com/pytest-dev/pytest-mock/issues/529">#529</a>:
Fixed <code>itertools._tee object has no attribute error</code> -- now
<code>duplicate_iterators=True</code> must be passed to
<code>mocker.spy</code> to duplicate iterators.</li>
</ul>
</blockquote>
</details>
<details>
<summary>Changelog</summary>
<p><em>Sourced from <a
href="https://github.com/pytest-dev/pytest-mock/blob/main/CHANGELOG.rst">pytest-mock's
changelog</a>.</em></p>
<blockquote>
<h2>3.15.1</h2>
<p><em>2025-09-16</em></p>
<ul>
<li><code>[vllm-project#529](pytest-dev/pytest-mock#529)
&lt;https://github.com/pytest-dev/pytest-mock/issues/529&gt;</code>_:
Fixed <code>itertools._tee object has no attribute error</code> -- now
<code>duplicate_iterators=True</code> must be passed to
<code>mocker.spy</code> to duplicate iterators.</li>
</ul>
<h2>3.15.0</h2>
<p><em>2025-09-04</em></p>
<ul>
<li>Python 3.8 (EOL) is no longer supported.</li>
<li><code>[vllm-project#524](pytest-dev/pytest-mock#524)
&lt;https://github.com/pytest-dev/pytest-mock/pull/524&gt;</code>_:
Added <code>spy_return_iter</code> to <code>mocker.spy</code>, which
contains a duplicate of the return value of the spied method if it is an
<code>Iterator</code>.</li>
</ul>
<h2>3.14.1 (2025-05-26)</h2>
<ul>
<li><code>[vllm-project#503](pytest-dev/pytest-mock#503)
&lt;https://github.com/pytest-dev/pytest-mock/pull/503&gt;</code>_:
Python 3.14 is now officially supported.</li>
</ul>
<h2>3.14.0 (2024-03-21)</h2>
<ul>
<li>
<p><code>[vllm-project#415](pytest-dev/pytest-mock#415)
&lt;https://github.com/pytest-dev/pytest-mock/pull/415&gt;</code>_:
<code>MockType</code> and <code>AsyncMockType</code> can be imported
from <code>pytest_mock</code> for type annotation purposes.</p>
</li>
<li>
<p><code>[vllm-project#420](pytest-dev/pytest-mock#420)
&lt;https://github.com/pytest-dev/pytest-mock/issues/420&gt;</code>_:
Fixed a regression which would cause <code>mocker.patch.object</code> to
not being properly cleared between tests.</p>
</li>
</ul>
<h2>3.13.0 (2024-03-21)</h2>
<ul>
<li><code>[vllm-project#417](pytest-dev/pytest-mock#417)
&lt;https://github.com/pytest-dev/pytest-mock/pull/417&gt;</code>_:
<code>spy</code> now has <code>spy_return_list</code>, which is a list
containing all the values returned by the spied function.</li>
<li><code>pytest-mock</code> now requires
<code>pytest&gt;=6.2.5</code>.</li>
<li><code>[vllm-project#410](pytest-dev/pytest-mock#410)
&lt;https://github.com/pytest-dev/pytest-mock/pull/410&gt;</code><em>:
pytest-mock's <code>setup.py</code> file is removed.
If you relied on this file, e.g. to install pytest using <code>setup.py
install</code>,
please see <code>Why you shouldn't invoke setup.py directly
&lt;https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html#summary&gt;</code></em>
for alternatives.</li>
</ul>
<h2>3.12.0 (2023-10-19)</h2>
<ul>
<li>Added support for Python 3.12.</li>
<li>Dropped support for EOL Python 3.7.</li>
<li><code>mocker.resetall()</code> now also resets mocks created by
<code>mocker.create_autospec</code>
(<code>[vllm-project#390](https://github.com/pytest-dev/pytest-mock/issues/390)</code>_).</li>
</ul>
<p>.. _<a
href="https://redirect.github.com/pytest-dev/pytest-mock/issues/390">#390</a>:
<a
href="https://redirect.github.com/pytest-dev/pytest-mock/pull/390">pytest-dev/pytest-mock#390</a></p>
<h2>3.11.1 (2023-06-15)</h2>
<p>(This release source code is identical to <code>3.11.0</code> except
a small internal fix to deployment/CI)</p>
<!-- raw HTML omitted -->
</blockquote>
<p>... (truncated)</p>
</details>
<details>
<summary>Commits</summary>
<ul>
<li><a
href="https://github.com/pytest-dev/pytest-mock/commit/e1b5c62a38c5a05cae614aef3847f240ba50d269"><code>e1b5c62</code></a>
Release 3.15.1</li>
<li><a
href="https://github.com/pytest-dev/pytest-mock/commit/184eb190d6be417f5f33727bcbc9704909479498"><code>184eb19</code></a>
Set <code>spy_return_iter</code> only when explicitly requested (<a
href="https://redirect.github.com/pytest-dev/pytest-mock/issues/537">#537</a>)</li>
<li><a
href="https://github.com/pytest-dev/pytest-mock/commit/4fa0088a0aa85eefb1313bd97adf43889bf1f647"><code>4fa0088</code></a>
[pre-commit.ci] pre-commit autoupdate (<a
href="https://redirect.github.com/pytest-dev/pytest-mock/issues/536">#536</a>)</li>
<li><a
href="https://github.com/pytest-dev/pytest-mock/commit/f5aff33ce71ed4620acc43dc41cb3b198bcf4cb0"><code>f5aff33</code></a>
Fix test failure with pytest 8+ and verbose mode (<a
href="https://redirect.github.com/pytest-dev/pytest-mock/issues/535">#535</a>)</li>
<li><a
href="https://github.com/pytest-dev/pytest-mock/commit/adc41873c9d6aa69b87e3f108c93a29c847869aa"><code>adc4187</code></a>
Bump actions/setup-python from 5 to 6 in the github-actions group (<a
href="https://redirect.github.com/pytest-dev/pytest-mock/issues/533">#533</a>)</li>
<li><a
href="https://github.com/pytest-dev/pytest-mock/commit/95ad5700609aae73c6f767b8cc2ccfb2483e0f5c"><code>95ad570</code></a>
[pre-commit.ci] pre-commit autoupdate (<a
href="https://redirect.github.com/pytest-dev/pytest-mock/issues/532">#532</a>)</li>
<li><a
href="https://github.com/pytest-dev/pytest-mock/commit/e696bf02c199b1f7d0c48adb450f40e5a75b699a"><code>e696bf0</code></a>
Fix standalone mock support (<a
href="https://redirect.github.com/pytest-dev/pytest-mock/issues/531">#531</a>)</li>
<li><a
href="https://github.com/pytest-dev/pytest-mock/commit/5b29b03ce9581cfcd867dd6c04a970fb2c861291"><code>5b29b03</code></a>
Fix gen-release-notes script</li>
<li><a
href="https://github.com/pytest-dev/pytest-mock/commit/7d22ef4e560351832e60687d8bd15ebe2785ff3b"><code>7d22ef4</code></a>
Merge pull request <a
href="https://redirect.github.com/pytest-dev/pytest-mock/issues/528">#528</a>
from pytest-dev/release-3.15.0</li>
<li><a
href="https://github.com/pytest-dev/pytest-mock/commit/90b29f89e2086c139a7b4fea89202faa192ee5a9"><code>90b29f8</code></a>
Update CHANGELOG for 3.15.0</li>
<li>Additional commits viewable in <a
href="https://github.com/pytest-dev/pytest-mock/compare/v3.14.0...v3.15.1">compare
view</a></li>
</ul>
</details>
<br />


Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.

[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)

---

<details>
<summary>Dependabot commands and options</summary>
<br />

You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits
that have been made to it
- `@dependabot show <dependency name> ignore conditions` will show all
of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop
Dependabot creating any more for this major version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop
Dependabot creating any more for this minor version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop
Dependabot creating any more for this dependency (unless you reopen the
PR or upgrade to it yourself)


</details>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
@fynnsu
Copy link
Copy Markdown
Collaborator

fynnsu commented May 20, 2026

The concern I have with this approach is that the transformation logic will have to be mirrored in vllm as well. Maybe we should be doing the transformation in vllm before extracting the hidden states?

Also ideally we could generalize some of this. There is already a base implementation of load_verifier_weights in src/speculators/model.py, if modify that implementation then eagle3, p-eagle, and dflash will all be able to use the code. We've also previously discussed a weight mapping layer to handle non-standard weight names (e.g. a user could pass in a dict mapping custom names to the standard ones). Something like that would let us handle the non-standard naming part better.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants