
dflash_generate: draft sampler ignores temperature; speculative decoding distribution diverges from target for temperature > 0 #74

@shaun0927

Description


Summary

In dflash_generate, the draft sampler is invoked without the user's temperature, while the target sampler does receive it. For temperature > 0 this means the draft is deterministic (greedy argmax) while the target samples stochastically, so the two paths sample from different distributions and the speculative-decoding distribution guarantees do not hold.

dflash/model.py:121 (draft):

```python
block_output_ids[:, 1:] = sample(draft_logits)  # uses default temperature=0.0
```

dflash/model.py:134 (target):

```python
posterior = sample(output.logits, temperature)
```

Reproduction (no model required)

Verbatim copy of sample() from model.py:48-54:

```python
import torch

def sample(logits, temperature=0.0):
    if temperature < 1e-5:
        return torch.argmax(logits, dim=-1)
    bsz, seq_len, vocab_size = logits.shape
    logits = logits.view(-1, vocab_size) / temperature
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).view(bsz, seq_len)

torch.manual_seed(0)
logits = torch.tensor([[[2.0, 1.5, 1.0, 0.5]]])
draft  = sum(int(sample(logits).item() == 0) for _ in range(4000))
target = sum(int(sample(logits, temperature=1.0).item() == 0) for _ in range(4000))
print(draft, target)   # 4000 1894
```

The draft picks the mode 100% of the time; the target picks it ~47% of the time. Acceptance is decided by token equality (block_output_ids[:, 1:] == posterior[:, :-1]), so this mismatch artificially depresses acceptance for any temperature > 0 and the accepted-token distribution does not match p_target.
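The token-equality check can be sketched in plain Python for a single sequence. This assumes, as is typical for block verification, that only the leading run of matching positions counts as accepted; accepted_prefix_length is a hypothetical helper, not a function in dflash/model.py, and the exact bookkeeping there may differ.

```python
def accepted_prefix_length(draft_tokens, posterior_tokens):
    """Hypothetical helper: count leading positions where draft and target agree.

    draft_tokens[i] plays the role of block_output_ids[:, 1:][i] and
    posterior_tokens[i] the role of posterior[:, :-1][i] for one sequence.
    """
    n = 0
    for d, t in zip(draft_tokens, posterior_tokens):
        if d != t:
            break  # the first mismatch ends the accepted prefix
        n += 1
    return n

# A greedy draft proposes the mode (token 0) at every position, while a
# stochastic target may sample a different token and cut the block short.
print(accepted_prefix_length([0, 0, 0, 0], [0, 0, 2, 0]))  # 2
print(accepted_prefix_length([0, 3, 1], [0, 3, 1]))        # 3
```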

Suggested fix

The minimal correctness improvement is to pass temperature to the draft sample():

```python
block_output_ids[:, 1:] = sample(draft_logits, temperature)
```
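To illustrate the effect of the one-line change without a model or torch, the sketch below mirrors sample() for a single position in plain Python. sample_token is a hypothetical stand-in that uses random.Random.choices in place of torch.multinomial. Without temperature the draft call returns the mode every time; with temperature=1.0 threaded through, its mode frequency should land near the exact softmax probability of token 0 for these logits, about 0.455 (e^2 / (e^2 + e^1.5 + e^1 + e^0.5)).

```python
import math
import random

def sample_token(logits, temperature=0.0, rng=random):
    # Hypothetical single-position mirror of sample(): greedy below 1e-5.
    if temperature < 1e-5:
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(x - m) for x in scaled]
    return rng.choices(range(len(logits)), weights=weights)[0]

logits = [2.0, 1.5, 1.0, 0.5]
temperature = 1.0
rng = random.Random(0)
n = 4000

before = sum(sample_token(logits) == 0 for _ in range(n))                   # draft call today
after = sum(sample_token(logits, temperature, rng) == 0 for _ in range(n))  # draft call with the fix
print(before, after / n)
```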

Passing temperature to the draft call makes draft and target sample under the same scheme. Acceptance is still based on token equality (not Leviathan-style rejection sampling), so it would also be helpful to document that dflash_generate provides exact-distribution semantics only for temperature == 0 and approximate semantics otherwise.
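For contrast, the Leviathan-style rule accepts a draft token x ~ q with probability min(1, p(x)/q(x)) and otherwise resamples from the normalized residual max(0, p - q); under this rule each emitted token is distributed exactly as the target p. The sketch below works on explicit probability vectors for a single token; the function names are hypothetical and none of this is dflash code.

```python
import random

def residual(p, q):
    # Normalized max(0, p - q): the distribution to resample from on rejection.
    diff = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(diff)
    return [d / total for d in diff]

def speculative_step(p, q, rng):
    """Emit one token with the accept/reject rule (p: target, q: draft)."""
    x = rng.choices(range(len(q)), weights=q)[0]  # draft proposal x ~ q
    if rng.random() < min(1.0, p[x] / q[x]):
        return x                                   # accept the draft token
    r = residual(p, q)
    return rng.choices(range(len(r)), weights=r)[0]

def emitted_distribution(p, q):
    # Exact output distribution of speculative_step, computed analytically.
    # q(x) * min(1, p(x)/q(x)) simplifies to min(p(x), q(x)).
    accept = [min(pi, qi) for pi, qi in zip(p, q)]
    reject_mass = 1.0 - sum(accept)
    r = residual(p, q) if reject_mass > 1e-12 else [0.0] * len(p)
    return [a + reject_mass * ri for a, ri in zip(accept, r)]

p = [0.5, 0.3, 0.2]  # target probabilities
q = [0.2, 0.5, 0.3]  # draft probabilities
print([round(v, 6) for v in emitted_distribution(p, q)])  # [0.5, 0.3, 0.2]

rng = random.Random(0)
draws = [speculative_step(p, q, rng) for _ in range(20000)]
print([draws.count(i) / len(draws) for i in range(3)])
```

emitted_distribution adds the accept path, min(p(x), q(x)), to the reject path, (1 minus the sum of min(p, q)) times the residual, and the two together recover p for any draft q.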

Happy to send a PR with the one-line change plus a docstring update if the team agrees.

Environment
