-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathspec_decode.py
More file actions
143 lines (121 loc) · 5.06 KB
/
Copy pathspec_decode.py
File metadata and controls
143 lines (121 loc) · 5.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""Distributed speculative decoding glue (ADR 0009 §4.3).

Two layers:

1. :func:`accept_block` — the greedy accept rule as a pure function
   over tensors. This is the *same* rule ``SpeculativeDecoder`` applies
   in-process; having it as a standalone, weight-free function lets the
   Linux CI gate pin the rule's semantics (the correctness-containment
   argument for accepting drafts from gossip-discovered peers rests
   entirely on this function never changing with draft provenance).

2. :class:`DistributedSpeculativeDecoder` — the v0.2 greedy
   spec-decode loop driven by a :class:`RemoteProposer`. It subclasses
   ``SpeculativeDecoder`` and changes nothing about the loop: the
   draft source is the only difference, which is the whole point —
   output remains bit-equivalent to local greedy AR decoding (modulo
   the lossy sink+window cache, exactly as documented for the local
   decoder).
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import List
import torch
from inference_engine.distributed.placement import SpecDecodePlacement
from inference_engine.distributed.proposer_service import RemoteProposer
from kv_cache_proposer.speculative import SpeculativeDecoder
@dataclass(frozen=True)
class BlockAcceptance:
    """Result of greedily verifying a single draft block.

    Attributes
    ----------
    accepted
        How many leading draft tokens were accepted (0..L).
    correction_or_bonus
        The verifier's preferred token at position (prefix + accepted).
        When ``accepted < L`` this is a correction, and it is guaranteed
        to differ from the first rejected draft token; when the whole
        block was accepted it is the bonus token.
    """

    # Field order is part of the positional-constructor interface.
    accepted: int
    correction_or_bonus: int
def accept_block(
    prev_logits: torch.Tensor,
    draft: List[int],
    block_logits: torch.Tensor,
) -> BlockAcceptance:
    """Apply the greedy (temperature-0) accept rule to one draft block.

    Walk the draft left to right, comparing each drafted token with the
    ``argmax`` of the logits that predict its position. Accept while they
    agree and stop at the first disagreement. This mirrors the inline
    loop in ``SpeculativeDecoder.generate``
    (kv_cache_proposer/speculative.py).

    Parameters
    ----------
    prev_logits
        Verifier logits ``[V]`` predicting the first draft position
        (i.e. ``verifier.next_token_logits`` before the block forward).
    draft
        The L drafted token ids.
    block_logits
        Verifier logits ``[L, V]`` from the parallel forward over the
        draft; row ``i`` predicts position ``i + 1``.

    Returns
    -------
    BlockAcceptance
        ``accepted`` is the length of the matching draft prefix;
        ``correction_or_bonus`` is the ``argmax`` of the logits at the
        position where acceptance stopped (``prev_logits`` if nothing
        was accepted, otherwise the row of the last accepted token).

    Raises
    ------
    ValueError
        If ``block_logits`` is not 2-D, if its row count differs from
        ``len(draft)``, or if ``draft`` is empty (checked in that order).
    """
    # Validation order is deliberate: shape, then row count, then emptiness.
    if block_logits.dim() != 2:
        raise ValueError(
            f"block_logits must be [L, V]; got shape {tuple(block_logits.shape)}"
        )
    if block_logits.shape[0] != len(draft):
        raise ValueError(
            f"block_logits has {block_logits.shape[0]} rows for a draft of "
            f"{len(draft)} tokens"
        )
    if not draft:
        raise ValueError("draft must be non-empty")

    # `greedy` is always the verifier's pick for the position at `cursor`.
    greedy = int(torch.argmax(prev_logits).item())
    cursor = 0
    while cursor < len(draft) and greedy == draft[cursor]:
        # Draft token confirmed: advance to the logits that follow it.
        greedy = int(torch.argmax(block_logits[cursor]).item())
        cursor += 1

    # On a mismatch `greedy` is the correction; on full acceptance it is
    # the bonus token predicted by the last row.
    return BlockAcceptance(accepted=cursor, correction_or_bonus=greedy)
class DistributedSpeculativeDecoder(SpeculativeDecoder):
    """Greedy speculative decoding whose proposer lives on another node.

    Everything about decoding — the loop, the accept rule, EOS handling
    and the streaming callback — comes unchanged from
    :class:`SpeculativeDecoder`. The single difference is where drafts
    come from: a :class:`RemoteProposer` gRPC client. If the remote side
    fails, ``generate`` raises ``RemoteProposerError``; the verifier's
    session state is intact, so the caller may re-plan placement and
    resume.
    """

    def __init__(
        self,
        proposer: RemoteProposer,
        verifier: object,
        block_size: int = 16,
        num_diffusion_steps: int = 16,
    ) -> None:
        """Forward all arguments, by keyword, to ``SpeculativeDecoder``.

        This override only narrows the accepted proposer type to
        :class:`RemoteProposer`; no extra state is set up here.
        """
        super().__init__(
            proposer=proposer,  # type: ignore[arg-type] - structural DLMProposer contract
            verifier=verifier,  # type: ignore[arg-type] - SinkWindowVerifier or MLX drop-in
            block_size=block_size,
            num_diffusion_steps=num_diffusion_steps,
        )

    @classmethod
    def from_placement(
        cls,
        placement: SpecDecodePlacement,
        verifier: object,
        *,
        block_size: int = 16,
        num_diffusion_steps: int = 16,
        timeout_s: float = 60.0,
    ) -> "DistributedSpeculativeDecoder":
        """Create a decoder from a planned placement and a loaded verifier.

        Loading ``verifier`` according to ``placement.verifier_model`` on
        this node is the caller's job (this node should be
        ``placement.verifier_node``). Nothing local is needed for the
        proposer: only the placed node's gRPC address and the model id
        are read from ``placement``.
        """
        # Pull out the two placement facts the remote client needs.
        proposer_address = placement.proposer_node.grpc_address
        proposer_model_id = placement.proposer_model.model_id
        remote = RemoteProposer(
            proposer_address,
            model_id=proposer_model_id,
            timeout_s=timeout_s,
        )
        return cls(
            proposer=remote,
            verifier=verifier,
            block_size=block_size,
            num_diffusion_steps=num_diffusion_steps,
        )