|
| 1 | +"""Real-model glue for the distributed DFlash+f_θ path (ADR 0009 §4 F3). |
| 2 | +
|
| 3 | +Implements the two model-bound contracts the framework-agnostic distributed |
| 4 | +machinery (``inference_engine.distributed.{dflash_service,fused_decode}``) needs: |
| 5 | +
|
| 6 | +* :class:`MLXRestorationDraftEngine` — host B: the torch DFlash drafter + f_θ |
| 7 | + projection + verifier embed/lm_head, behind ``RestorationDraftEngine``. |
| 8 | +* :class:`MLXRestoringVerifierAdapter` — host A: wraps |
| 9 | + ``MLXRestoredIncrementalVerifier`` as a ``RestoringVerifier``. |
| 10 | +
|
| 11 | +Plus :class:`InProcessDFlashProposer`, a ``RemoteDFlashProposer``-shaped object |
| 12 | +that calls a local engine directly (no gRPC) — used for the in-process |
| 13 | +byte-identical check. |
| 14 | +
|
| 15 | +This module imports mlx + torch + the v04 stack, so it lives in the MLX backend |
| 16 | +(not coverage-gated) and is validated end-to-end on-device, not by unit tests. |
| 17 | +Reuses the exact fused-path helpers from |
| 18 | +``scripts/research/k3_integrated_niah_eval_mac.py`` / |
| 19 | +``inference_engine.backends.mlx.fused_specdecode`` so the distributed split is |
| 20 | +numerically the same engine. |
| 21 | +""" |
| 22 | +from __future__ import annotations |
| 23 | + |
| 24 | +from dataclasses import dataclass |
| 25 | +from typing import Any, Dict, List, Sequence, Tuple |
| 26 | + |
| 27 | +from inference_engine.distributed.dflash_service import DraftResult, RestoreResult |
| 28 | +from inference_engine.distributed.fused_decode import CommitResult |
| 29 | +from inference_engine.distributed.tensor_codec import ( |
| 30 | + WireTensor, |
| 31 | + mlx_to_wire, |
| 32 | + torch_to_wire, |
| 33 | + wire_to_mlx, |
| 34 | + wire_to_torch, |
| 35 | +) |
| 36 | + |
| 37 | + |
| 38 | +# --------------------------------------------------------------------------- # |
| 39 | +# Host B: DFlash drafter + f_θ engine |
| 40 | +# --------------------------------------------------------------------------- # |
| 41 | +@dataclass |
| 42 | +class _Session: |
| 43 | + ctx_kv: Any = None |
| 44 | + |
| 45 | + |
| 46 | +class MLXRestorationDraftEngine: |
| 47 | + """``RestorationDraftEngine`` backed by a torch DFlash drafter + f_θ, using |
| 48 | + the MLX verifier's embedding for ``embed_fn``/``lm_head_fn`` (host A and B |
| 49 | + share the verifier weights in-process; for a true split host B replicates the |
| 50 | + embedding). Per-session drafter context K/V is held here (host B).""" |
| 51 | + |
| 52 | + def __init__( |
| 53 | + self, |
| 54 | + *, |
| 55 | + mlx_model: Any, |
| 56 | + text_model: Any, |
| 57 | + drafter: Any, |
| 58 | + f_theta: Any, |
| 59 | + embed_scale: float, |
| 60 | + device: Any, |
| 61 | + sink: int, |
| 62 | + window: int, |
| 63 | + force_f_theta: bool = True, |
| 64 | + ) -> None: |
| 65 | + import torch |
| 66 | + |
| 67 | + from inference_engine.backends.mlx.cross_model_dlm_verifier import ( |
| 68 | + kv_source_layer_map, |
| 69 | + mlx_full_attention_layer_indices, |
| 70 | + ) |
| 71 | + from inference_engine.backends.mlx.fused_specdecode import ( |
| 72 | + make_bridge_embed_lm_head, |
| 73 | + ) |
| 74 | + from scripts.research.k3_dflash_mlx_bridge import mx_to_torch, torch_to_mx |
| 75 | + |
| 76 | + self._torch = torch |
| 77 | + self.mlx_model = mlx_model |
| 78 | + self.text_model = text_model |
| 79 | + self.drafter = drafter |
| 80 | + self.f_theta = f_theta |
| 81 | + self.fcfg = f_theta.config |
| 82 | + self.embed_scale = float(embed_scale) |
| 83 | + self.device = device |
| 84 | + self.sink = int(sink) |
| 85 | + self.window = int(window) |
| 86 | + self.force_f_theta = bool(force_f_theta) |
| 87 | + self.n_layers = len(text_model.layers) |
| 88 | + self.exact_set = set(mlx_full_attention_layer_indices(text_model)) |
| 89 | + self.src_map = kv_source_layer_map(text_model) |
| 90 | + self._mx_to_torch = mx_to_torch |
| 91 | + self._torch_to_mx = torch_to_mx |
| 92 | + |
| 93 | + softcap = None |
| 94 | + for obj in (getattr(mlx_model, "language_model", None), mlx_model): |
| 95 | + cap = getattr(obj, "final_logit_softcapping", None) if obj is not None else None |
| 96 | + if cap: |
| 97 | + softcap = float(cap) |
| 98 | + break |
| 99 | + self._embed_fn, self._lm_head_fn = make_bridge_embed_lm_head( |
| 100 | + text_model, mx_to_torch=mx_to_torch, torch_to_mx=torch_to_mx, |
| 101 | + device=device, torch_dtype=torch.float32, softcap=softcap) |
| 102 | + self._sessions: Dict[str, _Session] = {} |
| 103 | + |
| 104 | + # --- prompt-time restoration (capture_drafter_kv + f_θ) ---------------- # |
| 105 | + def _capture_drafter_kv(self, ids: Sequence[int]): |
| 106 | + import mlx.core as mx |
| 107 | + |
| 108 | + torch = self._torch |
| 109 | + ids_mx = mx.array([list(ids)]) |
| 110 | + emb_mx = self.text_model.embed_tokens(ids_mx) |
| 111 | + embedded = self._mx_to_torch(emb_mx, dtype=torch.float32, device=self.device) |
| 112 | + layers = list(self.drafter.layers) |
| 113 | + k_cap: List[Any] = [None] * len(layers) |
| 114 | + v_cap: List[Any] = [None] * len(layers) |
| 115 | + handles = [] |
| 116 | + for i, layer in enumerate(layers): |
| 117 | + a = layer.self_attn |
| 118 | + handles.append(a.k_proj.register_forward_hook( |
| 119 | + lambda m, inp, out, i=i: k_cap.__setitem__(i, out.detach()))) |
| 120 | + handles.append(a.v_proj.register_forward_hook( |
| 121 | + lambda m, inp, out, i=i: v_cap.__setitem__(i, out.detach()))) |
| 122 | + try: |
| 123 | + with torch.no_grad(): |
| 124 | + T = embedded.size(1) |
| 125 | + qpos = torch.arange(T, device=self.device) |
| 126 | + h = embedded |
| 127 | + for layer in layers: |
| 128 | + h = layer(h, qpos, ctx_k=None, ctx_v=None) |
| 129 | + finally: |
| 130 | + for hh in handles: |
| 131 | + hh.remove() |
| 132 | + dh, ddim = self.fcfg.drafter_num_kv_heads, self.fcfg.drafter_head_dim |
| 133 | + d_k = [k_cap[i].view(1, -1, dh, ddim) for i in range(len(layers))] |
| 134 | + d_v = [v_cap[i].view(1, -1, dh, ddim) for i in range(len(layers))] |
| 135 | + return d_k, d_v |
| 136 | + |
| 137 | + def restore( |
| 138 | + self, session_id: str, prompt_ids: Sequence[int], *, |
| 139 | + sink: int, window: int, s5_exact_full_attn: bool, model_id: str, |
| 140 | + ) -> RestoreResult: |
| 141 | + from inference_engine.v04.kv_merge import compute_evicted_positions |
| 142 | + |
| 143 | + torch = self._torch |
| 144 | + self._sessions[session_id] = _Session() |
| 145 | + prompt_ids = list(prompt_ids) |
| 146 | + T = len(prompt_ids) |
| 147 | + evicted = compute_evicted_positions(T, self.sink, self.window) |
| 148 | + restored: List[Tuple[int, WireTensor, WireTensor]] = [] |
| 149 | + # S5 free lunch: with native exact-layer prefill and no force, the |
| 150 | + # verifier owns all needed K/V and nothing is shipped. |
| 151 | + if not (s5_exact_full_attn and not self.force_f_theta): |
| 152 | + d_k, d_v = self._capture_drafter_kv(prompt_ids) |
| 153 | + with torch.no_grad(): |
| 154 | + vk, vv = self.f_theta.forward_kv_pack(d_k, d_v) |
| 155 | + for li in range(self.n_layers): |
| 156 | + if self.src_map[li] != li: |
| 157 | + continue |
| 158 | + if s5_exact_full_attn and li in self.exact_set: |
| 159 | + continue # native cache owns exact (full-attn) layers |
| 160 | + k_mx = self._torch_to_mx(vk[li]) |
| 161 | + v_mx = self._torch_to_mx(vv[li]) |
| 162 | + restored.append((li, mlx_to_wire(k_mx), mlx_to_wire(v_mx))) |
| 163 | + return RestoreResult(restored=restored, evicted_positions=list(evicted), |
| 164 | + prompt_len=T) |
| 165 | + |
| 166 | + def seed_context( |
| 167 | + self, session_id: str, aux: Sequence[WireTensor], positions: Sequence[int], |
| 168 | + ) -> int: |
| 169 | + torch = self._torch |
| 170 | + aux_t = [wire_to_torch(w).to(self.device) for w in aux] |
| 171 | + pos = torch.tensor(list(positions), device=self.device) |
| 172 | + ctx = self.drafter.make_context_kv(aux_t, pos) |
| 173 | + self._sessions[session_id].ctx_kv = ctx |
| 174 | + return len(positions) |
| 175 | + |
| 176 | + def draft_block( |
| 177 | + self, session_id: str, *, bonus_token_id: int, context_len: int, |
| 178 | + block_size: int, |
| 179 | + ) -> DraftResult: |
| 180 | + if block_size <= 0: |
| 181 | + raise ValueError("block_size must be positive") |
| 182 | + sess = self._sessions[session_id] |
| 183 | + drafts = self.drafter.draft_block_cached( |
| 184 | + sess.ctx_kv, int(bonus_token_id), self._embed_fn, self._lm_head_fn, |
| 185 | + block_size=block_size, context_len=int(context_len)) |
| 186 | + return DraftResult(draft_token_ids=[int(t) for t in drafts], |
| 187 | + forward_passes=1, peak_activation_bytes=0) |
| 188 | + |
| 189 | + def extend_context( |
| 190 | + self, session_id: str, aux: Sequence[WireTensor], positions: Sequence[int], |
| 191 | + ) -> int: |
| 192 | + torch = self._torch |
| 193 | + sess = self._sessions[session_id] |
| 194 | + aux_t = [wire_to_torch(w).to(self.device) for w in aux] |
| 195 | + pos = torch.tensor(list(positions), device=self.device) |
| 196 | + new_kv = self.drafter.make_context_kv(aux_t, pos) |
| 197 | + sess.ctx_kv = self.drafter.extend_context_kv(sess.ctx_kv, new_kv) |
| 198 | + return int(positions[-1]) + 1 if len(positions) else context_len_unknown() |
| 199 | + |
| 200 | + def close_session(self, session_id: str) -> None: |
| 201 | + self._sessions.pop(session_id, None) |
| 202 | + |
| 203 | + |
| 204 | +def context_len_unknown() -> int: # pragma: no cover - defensive; positions never empty |
| 205 | + return 0 |
| 206 | + |
| 207 | + |
| 208 | +# --------------------------------------------------------------------------- # |
| 209 | +# Host A: MLX verifier adapter |
| 210 | +# --------------------------------------------------------------------------- # |
| 211 | +class MLXRestoringVerifierAdapter: |
| 212 | + """``RestoringVerifier`` over ``MLXRestoredIncrementalVerifier``.""" |
| 213 | + |
| 214 | + def __init__( |
| 215 | + self, *, adapter: Any, mlx_model: Any, aux_layer_ids: Sequence[int], |
| 216 | + embed_scale: float, bridge: Any, prefill_chunk_size: int = 512, |
| 217 | + ) -> None: |
| 218 | + import mlx.core as mx |
| 219 | + |
| 220 | + self._mx = mx |
| 221 | + self.adapter = adapter |
| 222 | + self.mlx_model = mlx_model |
| 223 | + self.aux_layer_ids = tuple(int(a) for a in aux_layer_ids) |
| 224 | + self.embed_scale = float(embed_scale) |
| 225 | + self.bridge = bridge |
| 226 | + self.prefill_chunk_size = int(prefill_chunk_size) |
| 227 | + self._prompt: List[int] = [] |
| 228 | + self._cstart = 0 |
| 229 | + self._prev = None |
| 230 | + self._block_logits = None |
| 231 | + self._candidate: List[int] = [] |
| 232 | + |
| 233 | + @property |
| 234 | + def context_len(self) -> int: |
| 235 | + return self.adapter._past_len |
| 236 | + |
| 237 | + def prefill( |
| 238 | + self, prompt_ids: Sequence[int], |
| 239 | + restored: Sequence[Tuple[int, WireTensor, WireTensor]], |
| 240 | + evicted_positions: Sequence[int], |
| 241 | + ) -> None: |
| 242 | + self._prompt = list(prompt_ids) |
| 243 | + rk: Dict[int, Any] = {} |
| 244 | + rv: Dict[int, Any] = {} |
| 245 | + for layer, k_w, v_w in restored: |
| 246 | + rk[layer] = wire_to_mlx(k_w) |
| 247 | + rv[layer] = wire_to_mlx(v_w) |
| 248 | + self.adapter.prefill( |
| 249 | + self._prompt, restored_k_per_layer=rk, restored_v_per_layer=rv, |
| 250 | + evicted_positions=list(evicted_positions), |
| 251 | + prefill_chunk_size=self.prefill_chunk_size, full_kv=False) |
| 252 | + self.adapter._capture_aux = True |
| 253 | + |
| 254 | + def aux_over_prompt(self) -> List[WireTensor]: |
| 255 | + from inference_engine.backends.mlx.fused_specdecode import capture_aux_hidden |
| 256 | + |
| 257 | + aux_mx = capture_aux_hidden( |
| 258 | + self.mlx_model, self._prompt, self.aux_layer_ids, |
| 259 | + embed_scale=self.embed_scale) |
| 260 | + return [torch_to_wire(self.bridge(a)) for a in aux_mx] |
| 261 | + |
| 262 | + def next_greedy(self) -> int: |
| 263 | + return int(self._mx.argmax(self.adapter.next_token_logits).item()) |
| 264 | + |
| 265 | + def verify_block(self, candidate: Sequence[int]) -> int: |
| 266 | + mx = self._mx |
| 267 | + candidate = list(candidate) |
| 268 | + self._cstart = self.adapter._past_len |
| 269 | + self._prev = self.adapter.next_token_logits |
| 270 | + self._block_logits = self.adapter.forward_block(candidate) |
| 271 | + self._candidate = candidate |
| 272 | + accepted = 0 |
| 273 | + running = self._prev |
| 274 | + for i, tok in enumerate(candidate): |
| 275 | + if int(mx.argmax(running).item()) != tok: |
| 276 | + break |
| 277 | + accepted += 1 |
| 278 | + running = self._block_logits[i] |
| 279 | + self._running = running |
| 280 | + return accepted |
| 281 | + |
| 282 | + def commit(self, accepted: int) -> CommitResult: |
| 283 | + torch_cat = __import__("torch").cat |
| 284 | + cand = self._candidate |
| 285 | + n_aux = len(self.aux_layer_ids) |
| 286 | + self.adapter.commit_or_truncate(forwarded=len(cand), accepted=accepted) |
| 287 | + cand_aux = self.adapter.last_aux_torch_slice(0, accepted) |
| 288 | + if accepted == len(cand): |
| 289 | + self.adapter.next_token_logits = self._block_logits[-1] |
| 290 | + tokens = list(cand) |
| 291 | + new_aux = [torch_cat([cand_aux[li]], dim=0).unsqueeze(0) for li in range(n_aux)] |
| 292 | + else: |
| 293 | + correction = int(self._mx.argmax(self._running).item()) |
| 294 | + self.adapter.append_token(correction) |
| 295 | + corr_aux = self.adapter.last_aux_torch_slice(0, 1) |
| 296 | + tokens = list(cand[:accepted]) + [correction] |
| 297 | + new_aux = [ |
| 298 | + torch_cat([cand_aux[li], corr_aux[li]], dim=0).unsqueeze(0) |
| 299 | + for li in range(n_aux) |
| 300 | + ] |
| 301 | + positions = list(range(self._cstart, self._cstart + len(tokens))) |
| 302 | + aux_wires = [torch_to_wire(a) for a in new_aux] |
| 303 | + return CommitResult(tokens=tokens, aux=aux_wires, positions=positions, stop=False) |
| 304 | + |
| 305 | + |
| 306 | +# --------------------------------------------------------------------------- # |
| 307 | +# In-process proposer (no gRPC) for the byte-identical check |
| 308 | +# --------------------------------------------------------------------------- # |
| 309 | +class InProcessDFlashProposer: |
| 310 | + """``RemoteDFlashProposer``-shaped wrapper calling a local engine directly.""" |
| 311 | + |
| 312 | + def __init__(self, engine: MLXRestorationDraftEngine, *, session_id: str = "inproc", |
| 313 | + sink: int = 4, window: int = 64) -> None: |
| 314 | + self.engine = engine |
| 315 | + self.session_id = session_id |
| 316 | + self.sink = sink |
| 317 | + self.window = window |
| 318 | + |
| 319 | + def restore(self, prompt_ids, *, sink, window, s5_exact_full_attn=True) -> RestoreResult: |
| 320 | + return self.engine.restore( |
| 321 | + self.session_id, prompt_ids, sink=sink, window=window, |
| 322 | + s5_exact_full_attn=s5_exact_full_attn, model_id="") |
| 323 | + |
| 324 | + def seed_context(self, aux, positions) -> int: |
| 325 | + return self.engine.seed_context(self.session_id, aux, positions) |
| 326 | + |
| 327 | + def draft_block(self, *, bonus_token_id, context_len, block_size) -> DraftResult: |
| 328 | + return self.engine.draft_block( |
| 329 | + self.session_id, bonus_token_id=bonus_token_id, |
| 330 | + context_len=context_len, block_size=block_size) |
| 331 | + |
| 332 | + def extend_context(self, aux, positions) -> int: |
| 333 | + return self.engine.extend_context(self.session_id, aux, positions) |
| 334 | + |
| 335 | + def close(self) -> None: |
| 336 | + self.engine.close_session(self.session_id) |
0 commit comments