Skip to content

Commit c360aba

Browse files
feat(mlx): real-model distributed DFlash+f_θ engine + verifier adapter + E2E
MLXRestorationDraftEngine (host B: torch DFlash + f_θ + verifier embed/lm_head), MLXRestoringVerifierAdapter (host A: wraps MLXRestoredIncrementalVerifier), and InProcessDFlashProposer. scripts/research/k3_distributed_dflash_e2e_mac.py loads the real models once and asserts the distributed path is byte-identical to greedy (in-process or loopback gRPC). Bridge presets mlx-distributed-dflash-e2e- inproc/-grpc. Co-authored-by: FluffyAIcode <FluffyAIcode@users.noreply.github.com>
1 parent 811115d commit c360aba

4 files changed

Lines changed: 587 additions & 0 deletions

File tree

Lines changed: 336 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,336 @@
1+
"""Real-model glue for the distributed DFlash+f_θ path (ADR 0009 §4 F3).
2+
3+
Implements the two model-bound contracts the framework-agnostic distributed
4+
machinery (``inference_engine.distributed.{dflash_service,fused_decode}``) needs:
5+
6+
* :class:`MLXRestorationDraftEngine` — host B: the torch DFlash drafter + f_θ
7+
projection + verifier embed/lm_head, behind ``RestorationDraftEngine``.
8+
* :class:`MLXRestoringVerifierAdapter` — host A: wraps
9+
``MLXRestoredIncrementalVerifier`` as a ``RestoringVerifier``.
10+
11+
Plus :class:`InProcessDFlashProposer`, a ``RemoteDFlashProposer``-shaped object
12+
that calls a local engine directly (no gRPC) — used for the in-process
13+
byte-identical check.
14+
15+
This module imports mlx + torch + the v04 stack, so it lives in the MLX backend
16+
(not coverage-gated) and is validated end-to-end on-device, not by unit tests.
17+
Reuses the exact fused-path helpers from
18+
``scripts/research/k3_integrated_niah_eval_mac.py`` /
19+
``inference_engine.backends.mlx.fused_specdecode`` so the distributed split is
20+
numerically the same engine.
21+
"""
22+
from __future__ import annotations
23+
24+
from dataclasses import dataclass
25+
from typing import Any, Dict, List, Sequence, Tuple
26+
27+
from inference_engine.distributed.dflash_service import DraftResult, RestoreResult
28+
from inference_engine.distributed.fused_decode import CommitResult
29+
from inference_engine.distributed.tensor_codec import (
30+
WireTensor,
31+
mlx_to_wire,
32+
torch_to_wire,
33+
wire_to_mlx,
34+
wire_to_torch,
35+
)
36+
37+
38+
# --------------------------------------------------------------------------- #
39+
# Host B: DFlash drafter + f_θ engine
40+
# --------------------------------------------------------------------------- #
41+
@dataclass
42+
class _Session:
43+
ctx_kv: Any = None
44+
45+
46+
class MLXRestorationDraftEngine:
47+
"""``RestorationDraftEngine`` backed by a torch DFlash drafter + f_θ, using
48+
the MLX verifier's embedding for ``embed_fn``/``lm_head_fn`` (host A and B
49+
share the verifier weights in-process; for a true split host B replicates the
50+
embedding). Per-session drafter context K/V is held here (host B)."""
51+
52+
def __init__(
53+
self,
54+
*,
55+
mlx_model: Any,
56+
text_model: Any,
57+
drafter: Any,
58+
f_theta: Any,
59+
embed_scale: float,
60+
device: Any,
61+
sink: int,
62+
window: int,
63+
force_f_theta: bool = True,
64+
) -> None:
65+
import torch
66+
67+
from inference_engine.backends.mlx.cross_model_dlm_verifier import (
68+
kv_source_layer_map,
69+
mlx_full_attention_layer_indices,
70+
)
71+
from inference_engine.backends.mlx.fused_specdecode import (
72+
make_bridge_embed_lm_head,
73+
)
74+
from scripts.research.k3_dflash_mlx_bridge import mx_to_torch, torch_to_mx
75+
76+
self._torch = torch
77+
self.mlx_model = mlx_model
78+
self.text_model = text_model
79+
self.drafter = drafter
80+
self.f_theta = f_theta
81+
self.fcfg = f_theta.config
82+
self.embed_scale = float(embed_scale)
83+
self.device = device
84+
self.sink = int(sink)
85+
self.window = int(window)
86+
self.force_f_theta = bool(force_f_theta)
87+
self.n_layers = len(text_model.layers)
88+
self.exact_set = set(mlx_full_attention_layer_indices(text_model))
89+
self.src_map = kv_source_layer_map(text_model)
90+
self._mx_to_torch = mx_to_torch
91+
self._torch_to_mx = torch_to_mx
92+
93+
softcap = None
94+
for obj in (getattr(mlx_model, "language_model", None), mlx_model):
95+
cap = getattr(obj, "final_logit_softcapping", None) if obj is not None else None
96+
if cap:
97+
softcap = float(cap)
98+
break
99+
self._embed_fn, self._lm_head_fn = make_bridge_embed_lm_head(
100+
text_model, mx_to_torch=mx_to_torch, torch_to_mx=torch_to_mx,
101+
device=device, torch_dtype=torch.float32, softcap=softcap)
102+
self._sessions: Dict[str, _Session] = {}
103+
104+
# --- prompt-time restoration (capture_drafter_kv + f_θ) ---------------- #
105+
def _capture_drafter_kv(self, ids: Sequence[int]):
106+
import mlx.core as mx
107+
108+
torch = self._torch
109+
ids_mx = mx.array([list(ids)])
110+
emb_mx = self.text_model.embed_tokens(ids_mx)
111+
embedded = self._mx_to_torch(emb_mx, dtype=torch.float32, device=self.device)
112+
layers = list(self.drafter.layers)
113+
k_cap: List[Any] = [None] * len(layers)
114+
v_cap: List[Any] = [None] * len(layers)
115+
handles = []
116+
for i, layer in enumerate(layers):
117+
a = layer.self_attn
118+
handles.append(a.k_proj.register_forward_hook(
119+
lambda m, inp, out, i=i: k_cap.__setitem__(i, out.detach())))
120+
handles.append(a.v_proj.register_forward_hook(
121+
lambda m, inp, out, i=i: v_cap.__setitem__(i, out.detach())))
122+
try:
123+
with torch.no_grad():
124+
T = embedded.size(1)
125+
qpos = torch.arange(T, device=self.device)
126+
h = embedded
127+
for layer in layers:
128+
h = layer(h, qpos, ctx_k=None, ctx_v=None)
129+
finally:
130+
for hh in handles:
131+
hh.remove()
132+
dh, ddim = self.fcfg.drafter_num_kv_heads, self.fcfg.drafter_head_dim
133+
d_k = [k_cap[i].view(1, -1, dh, ddim) for i in range(len(layers))]
134+
d_v = [v_cap[i].view(1, -1, dh, ddim) for i in range(len(layers))]
135+
return d_k, d_v
136+
137+
def restore(
138+
self, session_id: str, prompt_ids: Sequence[int], *,
139+
sink: int, window: int, s5_exact_full_attn: bool, model_id: str,
140+
) -> RestoreResult:
141+
from inference_engine.v04.kv_merge import compute_evicted_positions
142+
143+
torch = self._torch
144+
self._sessions[session_id] = _Session()
145+
prompt_ids = list(prompt_ids)
146+
T = len(prompt_ids)
147+
evicted = compute_evicted_positions(T, self.sink, self.window)
148+
restored: List[Tuple[int, WireTensor, WireTensor]] = []
149+
# S5 free lunch: with native exact-layer prefill and no force, the
150+
# verifier owns all needed K/V and nothing is shipped.
151+
if not (s5_exact_full_attn and not self.force_f_theta):
152+
d_k, d_v = self._capture_drafter_kv(prompt_ids)
153+
with torch.no_grad():
154+
vk, vv = self.f_theta.forward_kv_pack(d_k, d_v)
155+
for li in range(self.n_layers):
156+
if self.src_map[li] != li:
157+
continue
158+
if s5_exact_full_attn and li in self.exact_set:
159+
continue # native cache owns exact (full-attn) layers
160+
k_mx = self._torch_to_mx(vk[li])
161+
v_mx = self._torch_to_mx(vv[li])
162+
restored.append((li, mlx_to_wire(k_mx), mlx_to_wire(v_mx)))
163+
return RestoreResult(restored=restored, evicted_positions=list(evicted),
164+
prompt_len=T)
165+
166+
def seed_context(
167+
self, session_id: str, aux: Sequence[WireTensor], positions: Sequence[int],
168+
) -> int:
169+
torch = self._torch
170+
aux_t = [wire_to_torch(w).to(self.device) for w in aux]
171+
pos = torch.tensor(list(positions), device=self.device)
172+
ctx = self.drafter.make_context_kv(aux_t, pos)
173+
self._sessions[session_id].ctx_kv = ctx
174+
return len(positions)
175+
176+
def draft_block(
177+
self, session_id: str, *, bonus_token_id: int, context_len: int,
178+
block_size: int,
179+
) -> DraftResult:
180+
if block_size <= 0:
181+
raise ValueError("block_size must be positive")
182+
sess = self._sessions[session_id]
183+
drafts = self.drafter.draft_block_cached(
184+
sess.ctx_kv, int(bonus_token_id), self._embed_fn, self._lm_head_fn,
185+
block_size=block_size, context_len=int(context_len))
186+
return DraftResult(draft_token_ids=[int(t) for t in drafts],
187+
forward_passes=1, peak_activation_bytes=0)
188+
189+
def extend_context(
190+
self, session_id: str, aux: Sequence[WireTensor], positions: Sequence[int],
191+
) -> int:
192+
torch = self._torch
193+
sess = self._sessions[session_id]
194+
aux_t = [wire_to_torch(w).to(self.device) for w in aux]
195+
pos = torch.tensor(list(positions), device=self.device)
196+
new_kv = self.drafter.make_context_kv(aux_t, pos)
197+
sess.ctx_kv = self.drafter.extend_context_kv(sess.ctx_kv, new_kv)
198+
return int(positions[-1]) + 1 if len(positions) else context_len_unknown()
199+
200+
def close_session(self, session_id: str) -> None:
201+
self._sessions.pop(session_id, None)
202+
203+
204+
def context_len_unknown() -> int: # pragma: no cover - defensive; positions never empty
205+
return 0
206+
207+
208+
# --------------------------------------------------------------------------- #
209+
# Host A: MLX verifier adapter
210+
# --------------------------------------------------------------------------- #
211+
class MLXRestoringVerifierAdapter:
212+
"""``RestoringVerifier`` over ``MLXRestoredIncrementalVerifier``."""
213+
214+
def __init__(
215+
self, *, adapter: Any, mlx_model: Any, aux_layer_ids: Sequence[int],
216+
embed_scale: float, bridge: Any, prefill_chunk_size: int = 512,
217+
) -> None:
218+
import mlx.core as mx
219+
220+
self._mx = mx
221+
self.adapter = adapter
222+
self.mlx_model = mlx_model
223+
self.aux_layer_ids = tuple(int(a) for a in aux_layer_ids)
224+
self.embed_scale = float(embed_scale)
225+
self.bridge = bridge
226+
self.prefill_chunk_size = int(prefill_chunk_size)
227+
self._prompt: List[int] = []
228+
self._cstart = 0
229+
self._prev = None
230+
self._block_logits = None
231+
self._candidate: List[int] = []
232+
233+
@property
234+
def context_len(self) -> int:
235+
return self.adapter._past_len
236+
237+
def prefill(
238+
self, prompt_ids: Sequence[int],
239+
restored: Sequence[Tuple[int, WireTensor, WireTensor]],
240+
evicted_positions: Sequence[int],
241+
) -> None:
242+
self._prompt = list(prompt_ids)
243+
rk: Dict[int, Any] = {}
244+
rv: Dict[int, Any] = {}
245+
for layer, k_w, v_w in restored:
246+
rk[layer] = wire_to_mlx(k_w)
247+
rv[layer] = wire_to_mlx(v_w)
248+
self.adapter.prefill(
249+
self._prompt, restored_k_per_layer=rk, restored_v_per_layer=rv,
250+
evicted_positions=list(evicted_positions),
251+
prefill_chunk_size=self.prefill_chunk_size, full_kv=False)
252+
self.adapter._capture_aux = True
253+
254+
def aux_over_prompt(self) -> List[WireTensor]:
255+
from inference_engine.backends.mlx.fused_specdecode import capture_aux_hidden
256+
257+
aux_mx = capture_aux_hidden(
258+
self.mlx_model, self._prompt, self.aux_layer_ids,
259+
embed_scale=self.embed_scale)
260+
return [torch_to_wire(self.bridge(a)) for a in aux_mx]
261+
262+
def next_greedy(self) -> int:
263+
return int(self._mx.argmax(self.adapter.next_token_logits).item())
264+
265+
def verify_block(self, candidate: Sequence[int]) -> int:
266+
mx = self._mx
267+
candidate = list(candidate)
268+
self._cstart = self.adapter._past_len
269+
self._prev = self.adapter.next_token_logits
270+
self._block_logits = self.adapter.forward_block(candidate)
271+
self._candidate = candidate
272+
accepted = 0
273+
running = self._prev
274+
for i, tok in enumerate(candidate):
275+
if int(mx.argmax(running).item()) != tok:
276+
break
277+
accepted += 1
278+
running = self._block_logits[i]
279+
self._running = running
280+
return accepted
281+
282+
def commit(self, accepted: int) -> CommitResult:
283+
torch_cat = __import__("torch").cat
284+
cand = self._candidate
285+
n_aux = len(self.aux_layer_ids)
286+
self.adapter.commit_or_truncate(forwarded=len(cand), accepted=accepted)
287+
cand_aux = self.adapter.last_aux_torch_slice(0, accepted)
288+
if accepted == len(cand):
289+
self.adapter.next_token_logits = self._block_logits[-1]
290+
tokens = list(cand)
291+
new_aux = [torch_cat([cand_aux[li]], dim=0).unsqueeze(0) for li in range(n_aux)]
292+
else:
293+
correction = int(self._mx.argmax(self._running).item())
294+
self.adapter.append_token(correction)
295+
corr_aux = self.adapter.last_aux_torch_slice(0, 1)
296+
tokens = list(cand[:accepted]) + [correction]
297+
new_aux = [
298+
torch_cat([cand_aux[li], corr_aux[li]], dim=0).unsqueeze(0)
299+
for li in range(n_aux)
300+
]
301+
positions = list(range(self._cstart, self._cstart + len(tokens)))
302+
aux_wires = [torch_to_wire(a) for a in new_aux]
303+
return CommitResult(tokens=tokens, aux=aux_wires, positions=positions, stop=False)
304+
305+
306+
# --------------------------------------------------------------------------- #
307+
# In-process proposer (no gRPC) for the byte-identical check
308+
# --------------------------------------------------------------------------- #
309+
class InProcessDFlashProposer:
310+
"""``RemoteDFlashProposer``-shaped wrapper calling a local engine directly."""
311+
312+
def __init__(self, engine: MLXRestorationDraftEngine, *, session_id: str = "inproc",
313+
sink: int = 4, window: int = 64) -> None:
314+
self.engine = engine
315+
self.session_id = session_id
316+
self.sink = sink
317+
self.window = window
318+
319+
def restore(self, prompt_ids, *, sink, window, s5_exact_full_attn=True) -> RestoreResult:
320+
return self.engine.restore(
321+
self.session_id, prompt_ids, sink=sink, window=window,
322+
s5_exact_full_attn=s5_exact_full_attn, model_id="")
323+
324+
def seed_context(self, aux, positions) -> int:
325+
return self.engine.seed_context(self.session_id, aux, positions)
326+
327+
def draft_block(self, *, bonus_token_id, context_len, block_size) -> DraftResult:
328+
return self.engine.draft_block(
329+
self.session_id, bonus_token_id=bonus_token_id,
330+
context_len=context_len, block_size=block_size)
331+
332+
def extend_context(self, aux, positions) -> int:
333+
return self.engine.extend_context(self.session_id, aux, positions)
334+
335+
def close(self) -> None:
336+
self.engine.close_session(self.session_id)

inference_engine/bridge/manifest.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,56 @@ def _harness_preset(
104104
PRESETS: Dict[str, Preset] = {
105105
p.name: p
106106
for p in (
107+
Preset(
108+
name="mlx-distributed-dflash-e2e-inproc",
109+
description="Real-model distributed DFlash+f_θ E2E (in-process): loads "
110+
"the gemma-4 mlx-4bit verifier + torch DFlash + f_θ ONCE, "
111+
"runs the DistributedFusedDecoder over an in-process "
112+
"engine (full restore/seed/draft/verify/commit/extend + "
113+
"WireTensor codec), and asserts byte-identical to greedy. "
114+
"Validates the F3 data plane with real models, no 2x load.",
115+
command_templates=(
116+
(
117+
"python3", "scripts/research/k3_distributed_dflash_e2e_mac.py",
118+
"--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}",
119+
"--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}",
120+
"--f-theta-dir", "${ENV:KAKEYA_MAC_FTHETA_DIR}",
121+
"--max-new-tokens", "{max_new_tokens}",
122+
"--block-size", "{block_size}",
123+
),
124+
),
125+
timeout_minutes=90,
126+
params={
127+
"max_new_tokens": ("int:max_new_tokens", "48"),
128+
"block_size": ("int:block_size", "4"),
129+
},
130+
validate_reports=False,
131+
),
132+
Preset(
133+
name="mlx-distributed-dflash-e2e-grpc",
134+
description="Like mlx-distributed-dflash-e2e-inproc but routes the "
135+
"proposer through a real loopback gRPC DFlashProposerService "
136+
"(--grpc): exercises the wire (Restore/SeedContext/DraftBlock/"
137+
"ExtendContext over gRPC + WireTensor (de)serialization) and "
138+
"measures loopback RTT, still asserting byte-identical.",
139+
command_templates=(
140+
(
141+
"python3", "scripts/research/k3_distributed_dflash_e2e_mac.py",
142+
"--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}",
143+
"--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}",
144+
"--f-theta-dir", "${ENV:KAKEYA_MAC_FTHETA_DIR}",
145+
"--max-new-tokens", "{max_new_tokens}",
146+
"--block-size", "{block_size}",
147+
"--grpc",
148+
),
149+
),
150+
timeout_minutes=90,
151+
params={
152+
"max_new_tokens": ("int:max_new_tokens", "48"),
153+
"block_size": ("int:block_size", "4"),
154+
},
155+
validate_reports=False,
156+
),
107157
Preset(
108158
name="mlx-distributed-spec-decode-demo",
109159
description="ADR 0009 distributed spec-decode, on-device: two local "

0 commit comments

Comments
 (0)