Skip to content

Commit e6e6543

Browse files
feat(distributed): torch CUDA DFlash+f_θ engine + GPU server + cross-host E2E
TorchRestorationDraftEngine (inference_engine/v04/dflash_distributed_engine.py): the pure-torch RestorationDraftEngine for a GPU host, reusing the CUDA fused machinery (CrossModelDLMRestoredVerifier.project_drafter_kv, Gap-B torch embed). k3_dflash_proposer_server.py serves it. E2E script gains --remote-addr (true cross-host) and uses block_size=1 as the greedy baseline. MLX adapter now filters restored layers to the verifier's KV-source layers (gemma-4 cross-layer sharing). Preset mlx-distributed-dflash-e2e-crosshost (Mac verifier <-> GPU proposer via vast-mapped port). Co-authored-by: FluffyAIcode <FluffyAIcode@users.noreply.github.com>
1 parent 653364d commit e6e6543

6 files changed

Lines changed: 345 additions & 132 deletions

File tree

inference_engine/backends/mlx/dflash_distributed.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,16 @@ def __init__(
229229
self._prev = None
230230
self._block_logits = None
231231
self._candidate: List[int] = []
232+
# gemma-4 shares K/V across layers; the MLX verifier injects restored K/V
233+
# only at "source" layers (src_map[li]==li). A torch host B ships every
234+
# non-exact layer; filter to what THIS verifier consumes.
235+
from inference_engine.backends.mlx.cross_model_dlm_verifier import (
236+
kv_source_layer_map,
237+
resolve_mlx_text_model,
238+
)
239+
_tm = resolve_mlx_text_model(mlx_model)
240+
_src = kv_source_layer_map(_tm)
241+
self._source_layers = {li for li in range(len(_src)) if _src[li] == li}
232242

233243
@property
234244
def context_len(self) -> int:
@@ -243,6 +253,8 @@ def prefill(
243253
rk: Dict[int, Any] = {}
244254
rv: Dict[int, Any] = {}
245255
for layer, k_w, v_w in restored:
256+
if layer not in self._source_layers:
257+
continue # non-source layer (shared K/V) — verifier doesn't inject it
246258
rk[layer] = wire_to_mlx(k_w)
247259
rv[layer] = wire_to_mlx(v_w)
248260
self.adapter.prefill(

inference_engine/bridge/manifest.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,30 @@ def _harness_preset(
154154
},
155155
validate_reports=False,
156156
),
157+
Preset(
158+
name="mlx-distributed-dflash-e2e-crosshost",
159+
description="TRUE cross-host: gemma-4 mlx-4bit verifier on THIS Mac ↔ a "
160+
"remote torch DFlash+f_θ DFlashProposerService on a GPU "
161+
"(107.206.71.138:43032, the vast map of the H200's :6006). "
162+
"Runs greedy (block=1) + distributed (block=N) over the wire "
163+
"and asserts byte-identical, reporting real cross-host RTT.",
164+
command_templates=(
165+
(
166+
"python3", "scripts/research/k3_distributed_dflash_e2e_mac.py",
167+
"--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}",
168+
"--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}",
169+
"--remote-addr", "107.206.71.138:43032",
170+
"--max-new-tokens", "{max_new_tokens}",
171+
"--block-size", "{block_size}",
172+
),
173+
),
174+
timeout_minutes=90,
175+
params={
176+
"max_new_tokens": ("int:max_new_tokens", "48"),
177+
"block_size": ("int:block_size", "4"),
178+
},
179+
validate_reports=False,
180+
),
157181
Preset(
158182
name="mlx-distributed-spec-decode-demo",
159183
description="ADR 0009 distributed spec-decode, on-device: two local "
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
"""Torch/CUDA ``RestorationDraftEngine`` (ADR 0009 §4 F3, host B on a GPU).
2+
3+
The pure-torch twin of ``inference_engine.backends.mlx.dflash_distributed
4+
.MLXRestorationDraftEngine``: a remote DFlash drafter + f_θ projection that runs
5+
on a CUDA host (no MLX), feeding a gemma-4 MLX verifier on another host. Reuses
6+
the CUDA fused-engine machinery (``CrossModelDLMRestoredVerifier.project_drafter_kv``,
7+
``DFlashDrafter`` context K/V, the Gap-B torch embed/lm_head).
8+
9+
Imports torch + transformers + the v04 stack, so it lives in v04 (not coverage-
10+
gated) and is validated on-device.
11+
"""
12+
from __future__ import annotations
13+
14+
from dataclasses import dataclass
15+
from typing import Any, Dict, List, Sequence, Tuple
16+
17+
from inference_engine.distributed.dflash_service import DraftResult, RestoreResult
18+
from inference_engine.distributed.tensor_codec import (
19+
WireTensor,
20+
torch_to_wire,
21+
wire_to_torch,
22+
)
23+
24+
25+
def build_torch_embed_lm_head(verifier_model, softcap):
    """Build the Gap-B torch ``(embed_fn, lm_head_fn)`` pair for the verifier.

    Gap-B torch embed/lm_head over the verifier's tied embedding (no
    ×sqrt(hidden) on embed; tied head + final-logit softcap). Mirrors
    scripts/research/k3_specdecode_gpu_bench._build_embed_lm_head.

    Args:
        verifier_model: Object exposing ``get_input_embeddings()`` and
            ``get_output_embeddings()``; each returned module must carry a
            ``weight`` tensor.
        softcap: Final-logit soft cap. Any falsy value (``None``, ``0``)
            disables capping.

    Returns:
        ``(embed_fn, lm_head_fn)``. ``embed_fn(ids)`` returns the float32
        embedding rows for ``ids``; ``lm_head_fn(h)`` returns float32 logits
        ``h @ head_w.T``, soft-capped as ``softcap * tanh(logits / softcap)``
        when ``softcap`` is truthy.
    """
    import torch
    import torch.nn.functional as F

    emb_w = verifier_model.get_input_embeddings().weight.detach()
    # A model may report no separate output embedding module; the head is
    # tied to the input embedding in that case, so reuse its weight instead of
    # failing on ``None.weight``.
    out_emb = verifier_model.get_output_embeddings()
    head_w = emb_w if out_emb is None else out_emb.weight.detach()

    def embed_fn(ids: torch.Tensor) -> torch.Tensor:
        # Plain lookup: deliberately no sqrt(hidden) scaling here.
        return F.embedding(ids, emb_w).float()

    def lm_head_fn(h: torch.Tensor) -> torch.Tensor:
        # Matmul in the head's dtype, then upcast so the cap runs in float32.
        logits = (h.to(head_w.dtype) @ head_w.t()).float()
        if softcap:
            logits = softcap * torch.tanh(logits / softcap)
        return logits

    return embed_fn, lm_head_fn
45+
46+
47+
@dataclass
48+
class _Session:
49+
ctx_kv: Any = None
50+
51+
52+
class TorchRestorationDraftEngine:
    """``RestorationDraftEngine`` on a CUDA host: torch DFlash + f_θ + a gemma-4
    verifier (used only for its embedding / drafter-KV capture).

    Sessions are created (or reset) by :meth:`restore`; :meth:`seed_context`,
    :meth:`draft_block` and :meth:`extend_context` raise ``KeyError`` for a
    session id that has not been restored or has been closed.
    """

    def __init__(
        self, *, verifier_model: Any, drafter: Any, f_theta: Any, device: Any,
        sink: int, window: int, force_f_theta: bool = True,
    ) -> None:
        """Wire up the drafter, the f_θ projection and the embed/lm_head pair.

        Args:
            verifier_model: Torch verifier; supplies the embeddings, the
                ``config`` read for the logit softcap, and the layer layout.
            drafter: DFlash drafter providing ``make_context_kv``,
                ``extend_context_kv`` and ``draft_block_cached``.
            f_theta: Projection handed to ``CrossModelDLMRestoredVerifier``.
            device: Torch device on which request tensors are materialized.
            sink: Sink size (coerced to ``int``).
            window: Window size (coerced to ``int``).
            force_f_theta: When true, :meth:`restore` always runs the
                projection, even if the request sets ``s5_exact_full_attn``.
        """
        import torch

        from inference_engine.v04.cross_model_dlm_verifier import (
            CrossModelDLMRestoredVerifier,
            full_attention_layer_indices,
        )

        self._torch = torch
        self.device = device
        self.sink = int(sink)
        self.window = int(window)
        self.force_f_theta = bool(force_f_theta)
        self.drafter = drafter
        self.exact_set = set(full_attention_layer_indices(verifier_model))
        # Pass the int-normalized sizes so the projection and the eviction
        # computed in restore() are guaranteed to agree.
        self._restored = CrossModelDLMRestoredVerifier(
            verifier_model=verifier_model, drafter=drafter, f_theta=f_theta,
            sink_size=self.sink, window_size=self.window,
            exact_layer_indices=self.exact_set)
        self._embed_fn, self._lm_head_fn = build_torch_embed_lm_head(
            verifier_model, self._resolve_softcap(verifier_model))
        self._sessions: Dict[str, _Session] = {}

    @staticmethod
    def _resolve_softcap(verifier_model: Any):
        """Return ``final_logit_softcapping`` as a float, or ``None`` if unset.

        Reads ``verifier_model.config`` first and falls back to
        ``config.text_config``; a missing or falsy value yields ``None``.
        """
        vcfg = getattr(verifier_model, "config", None)
        if vcfg is None:
            return None
        cap = getattr(vcfg, "final_logit_softcapping", None)
        if cap is None:
            cap = getattr(
                getattr(vcfg, "text_config", None), "final_logit_softcapping", None)
        return float(cap) if cap else None

    def _get_session(self, session_id: str) -> Any:
        """Return the live session for ``session_id``.

        Raises:
            KeyError: If the session was never restored or was closed.
        """
        try:
            return self._sessions[session_id]
        except KeyError:
            raise KeyError(
                f"unknown session {session_id!r}; call restore() first") from None

    def restore(
        self, session_id: str, prompt_ids: Sequence[int], *,
        sink: int, window: int, s5_exact_full_attn: bool, model_id: str,
    ) -> RestoreResult:
        """Start (or reset) a session and project restored K/V for the prompt.

        Args:
            session_id: Session key; any existing state under it is replaced.
            prompt_ids: Prompt token ids.
            sink: Unused here; see the review note in the body.
            window: Unused here; see the review note in the body.
            s5_exact_full_attn: When true, layers in ``exact_set`` are omitted
                from the result; if ``force_f_theta`` is also false the
                projection is skipped and no layers are returned.
            model_id: Unused here.

        Returns:
            ``RestoreResult`` with the ``(layer, K, V)`` wire tensors, the
            evicted positions and the prompt length.
        """
        from inference_engine.v04.kv_merge import compute_evicted_positions

        torch = self._torch
        self._sessions[session_id] = _Session()
        prompt_ids = list(prompt_ids)
        T = len(prompt_ids)
        # NOTE(review): eviction uses the engine-level sink/window from
        # construction; the per-request `sink`/`window` (and `model_id`) are
        # not consulted. Confirm callers always send matching values.
        evicted = compute_evicted_positions(T, self.sink, self.window)
        restored: List[Tuple[int, WireTensor, WireTensor]] = []
        if self.force_f_theta or not s5_exact_full_attn:
            ids = torch.tensor([prompt_ids], dtype=torch.long, device=self.device)
            with torch.no_grad():
                vk, vv = self._restored.project_drafter_kv(ids)
            for li in range(len(vk)):
                if s5_exact_full_attn and li in self.exact_set:
                    continue  # native cache owns exact (full-attn) layers
                restored.append((li, torch_to_wire(vk[li]), torch_to_wire(vv[li])))
        return RestoreResult(restored=restored, evicted_positions=list(evicted),
                             prompt_len=T)

    def seed_context(
        self, session_id: str, aux: Sequence[WireTensor], positions: Sequence[int],
    ) -> int:
        """Build the session's drafter context K/V from scratch.

        Args:
            session_id: Session previously created by :meth:`restore`.
            aux: Wire tensors decoded and moved to ``self.device``.
            positions: Positions passed to the drafter alongside ``aux``.

        Returns:
            ``len(positions)``.

        Raises:
            KeyError: If the session is unknown.
        """
        torch = self._torch
        # Resolve the session first so an unknown id fails before any tensor
        # decoding or device work is done.
        sess = self._get_session(session_id)
        aux_t = [wire_to_torch(w).to(self.device) for w in aux]
        # Explicit integer dtype: torch.tensor([]) would otherwise be float32.
        pos = torch.tensor(list(positions), dtype=torch.long, device=self.device)
        sess.ctx_kv = self.drafter.make_context_kv(aux_t, pos)
        return len(positions)

    def draft_block(
        self, session_id: str, *, bonus_token_id: int, context_len: int,
        block_size: int,
    ) -> DraftResult:
        """Draft one block of tokens from the session's cached context K/V.

        Returns:
            ``DraftResult`` with the drafted token ids; ``forward_passes`` is
            reported as 1 and ``peak_activation_bytes`` as 0.

        Raises:
            ValueError: If ``block_size`` is not positive.
            KeyError: If the session is unknown.
        """
        if block_size <= 0:
            raise ValueError("block_size must be positive")
        sess = self._get_session(session_id)
        drafts = self.drafter.draft_block_cached(
            sess.ctx_kv, int(bonus_token_id), self._embed_fn, self._lm_head_fn,
            block_size=block_size, context_len=int(context_len))
        return DraftResult(draft_token_ids=[int(t) for t in drafts],
                           forward_passes=1, peak_activation_bytes=0)

    def extend_context(
        self, session_id: str, aux: Sequence[WireTensor], positions: Sequence[int],
    ) -> int:
        """Append newly accepted positions to the session's context K/V.

        Returns:
            ``positions[-1] + 1``, or 0 when ``positions`` is empty.

        Raises:
            KeyError: If the session is unknown.
        """
        torch = self._torch
        sess = self._get_session(session_id)
        aux_t = [wire_to_torch(w).to(self.device) for w in aux]
        # Explicit integer dtype: torch.tensor([]) would otherwise be float32.
        pos = torch.tensor(list(positions), dtype=torch.long, device=self.device)
        new_kv = self.drafter.make_context_kv(aux_t, pos)
        sess.ctx_kv = self.drafter.extend_context_kv(sess.ctx_kv, new_kv)
        return int(positions[-1]) + 1 if len(positions) else 0

    def close_session(self, session_id: str) -> None:
        """Drop the session's state; closing an unknown session is a no-op."""
        self._sessions.pop(session_id, None)
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""Serve a remote DFlash+f_θ DFlashProposerService on a CUDA host (ADR 0009 F3).
2+
3+
Loads a torch gemma-4 verifier (for its embedding / drafter-KV capture), the
4+
torch DFlash drafter, and f_θ, wraps them in a TorchRestorationDraftEngine, and
5+
serves the gRPC DFlashProposerService. The gemma-4 MLX verifier on another host
6+
drives it via RemoteDFlashProposer.
7+
"""
8+
from __future__ import annotations
9+
10+
import argparse
11+
import asyncio
12+
import sys
13+
14+
15+
async def main() -> int:
    """Load the verifier, drafter and f_θ, then serve DFlashProposerService.

    Parses the command line, builds a ``TorchRestorationDraftEngine`` on CUDA
    when available (CPU otherwise), registers it on a gRPC asyncio server bound
    to ``--bind`` and blocks until the server terminates.

    Returns:
        0 once the server has terminated.

    Raises:
        SystemExit: Via ``argparse`` when ``--dtype`` is not a torch dtype.
        RuntimeError: When the server cannot bind ``--bind``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--verifier-id", default="google/gemma-4-26B-A4B-it")
    ap.add_argument("--drafter-id", default="z-lab/gemma-4-26B-A4B-it-DFlash")
    ap.add_argument("--f-theta-dir", default="results/research/f_theta_v5_s5_sliding")
    ap.add_argument("--bind", default="0.0.0.0:6006")
    ap.add_argument("--sink", type=int, default=4)
    ap.add_argument("--window", type=int, default=64)
    ap.add_argument("--dtype", default="bfloat16")
    args = ap.parse_args()

    # Imported after argument parsing, so `--help` and argument errors do not
    # require these packages to be importable.
    import grpc
    import torch
    from transformers import AutoModelForCausalLM

    from inference_engine.distributed.dflash_service import add_dflash_proposer_service
    from inference_engine.v04 import DFlashDrafter, FThetaProjection
    from inference_engine.v04.dflash_distributed_engine import TorchRestorationDraftEngine

    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Reject names that are missing from torch or that resolve to something
    # other than a dtype (e.g. `--dtype nn`) before loading any weights.
    dtype = getattr(torch, args.dtype, None)
    if not isinstance(dtype, torch.dtype):
        ap.error(f"--dtype {args.dtype!r} is not a torch dtype "
                 "(e.g. bfloat16, float16, float32)")
    print(f"[server] loading verifier {args.verifier_id} ({dtype}) on {dev}", file=sys.stderr, flush=True)
    verifier = AutoModelForCausalLM.from_pretrained(
        args.verifier_id, dtype=dtype, attn_implementation="eager").to(dev).eval()
    for p in verifier.parameters():
        p.requires_grad_(False)
    print(f"[server] loading drafter {args.drafter_id} + f_θ {args.f_theta_dir}", file=sys.stderr, flush=True)
    drafter = DFlashDrafter.from_pretrained(args.drafter_id, dtype=dtype).to(dev).eval()
    for p in drafter.parameters():
        p.requires_grad_(False)
    f_theta = FThetaProjection.from_pretrained(args.f_theta_dir, dtype=torch.float32, device=dev)

    engine = TorchRestorationDraftEngine(
        verifier_model=verifier, drafter=drafter, f_theta=f_theta, device=dev,
        sink=args.sink, window=args.window, force_f_theta=True)

    # 512 MiB message caps in both directions.
    server = grpc.aio.server(options=[
        ("grpc.max_send_message_length", 512 * 1024 * 1024),
        ("grpc.max_receive_message_length", 512 * 1024 * 1024)])
    add_dflash_proposer_service(server, engine)
    # NOTE(review): plaintext, unauthenticated listener; the default bind
    # address exposes it on every interface. Confirm this is only reachable
    # through a trusted tunnel or port mapping.
    bound_port = server.add_insecure_port(args.bind)
    if not bound_port:
        # A falsy port means the bind did not succeed; fail here instead of
        # announcing a server nobody can reach.
        raise RuntimeError(f"failed to bind DFlashProposerService to {args.bind}")
    await server.start()
    print(f"[server] DFlashProposerService serving on {args.bind} (ready)", file=sys.stderr, flush=True)
    await server.wait_for_termination()
    return 0
60+
61+
62+
if __name__ == "__main__":
    # Run the async entry point and use its return value as the exit status.
    sys.exit(asyncio.run(main()))

0 commit comments

Comments
 (0)