Commit 274d41e

feat(mtp): port MTP foundation onto main + extract common::mtp::warm_and_decode (step 3.1)
Rebase of the MTP-via-daemon work onto latest main (PRs #213, #210, #208, #207 already merged) plus the first slice of howard0su's PR #214 review request: move MTP orchestration into dflash/src/common/ behind a generic entry point any ModelBackend can call.

## What landed

### Foundation (rebase port, ~5k LOC)

- `dflash/src/qwen36/qwen36_mtp.{cpp,h}` (2.3k LOC): Qwen3.6 native-heads MTP module (Qwen36MtpModule, implements INativeMtp)
- `dflash/src/qwen36/qwen36_mtp_graph.{cpp,h}`: MTP head forward graph
- `dflash/src/qwen36/qwen36_mtp_loader.cpp`: NextN tensor loader from GGUF
- `dflash/src/common/mtp_interface.h`: abstract IMtpModule + flavor mixins
- `dflash/src/common/mtp_chain_runner.{cpp,h}`: generic γ-loop runner
- `dflash/src/common/{gguf_metadata,gguf_mmap,step_graph,model_backend}.h` + `attn_masks.h` + `dflash_target.h` updates: shared infrastructure
- `dflash/src/qwen35/qwen35_backend.{cpp,h}`: extended with optional Qwen36MtpModule, init_mtp_, warm_mtp_for_prompt_, do_mtp_prefill_, do_mtp_decode_ (will be slimmed once orchestrator absorbs them, step 3.3)
- `dflash/src/qwen35/qwen35_daemon.{cpp,h}`: DaemonArgs carry MTP fields
- `dflash/src/qwen35/qwen35_dflash_target.{cpp,h}` + `qwen35_target_graph.cpp`: hidden-sequence capture path for MTP head warming
- `dflash/test/test_dflash.cpp`: daemon dispatch routes `--daemon --mtp-gguf` to run_qwen35_daemon (file-mode harness preserved)
- `dflash/scripts/server.py`: `--mtp-gguf`/`--mtp-gamma`/`--mtp-draft-source` CLI flags, MTP-mode spawn-cmd branch, layered on top of mrciffa's thinking-default fixes (commit 998b280) without conflict

### Step 3.1: common::mtp::warm_and_decode entry point (TDD red→green)

Howard's review:

> "MTP should be simple as additional weights of modelbackend. If a model
> contains MTP support (gemma4 or qwen3.5), the logic can handle it. In
> other words, the logic should be in /common which can potentially
> leverage by any modelbackend if they support mtp."

Carved out the public surface for the future orchestrator:

    GenerateResult dflash27b::common::mtp::warm_and_decode(
        ModelBackend * backend,
        const GenerateRequest & req,
        const DaemonIO & io);

New files:

- `dflash/src/common/mtp_orchestrator.{cpp,h}`: header pins the signature, cpp is a minimal stub that only handles guard cases (null backend, no MTP support, empty prompt). Real warm + decode body lands in step 3.2, driven by additional red→green tests.
- `dflash/test/test_common_mtp_orchestrator.cpp`: three guard tests written and watched fail BEFORE the stub existed (compile-time RED: "common/mtp_orchestrator.h: No such file or directory"), then GREEN after the stub returned matching error strings.

Test results:

    T1 null_backend PASS
    T2 backend_without_mtp PASS
    T3 empty_prompt PASS
    ALL PASS

## Steps 3.2-3.5 (separate commits, this PR)

3.2 fill warm_and_decode body (chunked prefill via DFlashTarget::verify_batch + hidden capture + MtpChainRunner.run); red test = identical token IDs vs reference run_qwen36_mtp_harness on a fixed prompt.

3.3 replace Qwen35Backend::do_mtp_decode_/do_mtp_prefill_ with calls to common::mtp::warm_and_decode; delete the qwen35-local helpers.

3.4 stub Gemma4Backend MTP override using the same common entry point to prove the interface is generic (not Qwen35-specific).

3.5 audit common/mtp_orchestrator + mtp_chain_runner for any hand-rolled CPU loops; replace with ggml primitives per howard's point #1. Then retest 24K baseline post-RoPE-fix (howard's other comment) and update PR description with current numbers.

Addresses:

- davide221 #214#issuecomment-4472910706 (merge conflicts): rebased
- howard0su #214#review (changes requested points 2, 3, 4): first slice
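The three guard cases that tests T1 to T3 exercise can be pictured with a minimal sketch. Everything here is a stand-in: the real `ModelBackend`, `GenerateRequest`, and `GenerateResult` types are not shown in this commit view, so the field names, the `supports_mtp()` accessor, and the error strings are assumptions chosen for illustration. The `DaemonIO` parameter is left out to keep the example self-contained.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins for the real types in dflash/src/common.
struct ModelBackend {
    bool has_mtp = false;
    bool supports_mtp() const { return has_mtp; }  // assumed accessor name
};
struct GenerateRequest { std::vector<int> prompt_tokens; };       // assumed field
struct GenerateResult  { bool ok = false; std::string error; };  // assumed fields

// Guard-only stub: reject bad inputs before any warm or decode work starts.
// The error strings are illustrative, not the ones the real stub returns.
GenerateResult warm_and_decode_stub(ModelBackend * backend,
                                    const GenerateRequest & req) {
    GenerateResult r;
    if (backend == nullptr)        { r.error = "null backend";               return r; }
    if (!backend->supports_mtp())  { r.error = "backend has no MTP support"; return r; }
    if (req.prompt_tokens.empty()) { r.error = "empty prompt";               return r; }
    r.error = "not implemented";   // the real body is planned for step 3.2
    return r;
}
```

Checking the guards in this order means a null pointer is never dereferenced, and a backend without MTP weights is rejected before the prompt is inspected.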
1 parent 7b781b8 commit 274d41e

31 files changed

Lines changed: 6289 additions & 28 deletions

dflash/CMakeLists.txt

Lines changed: 10 additions & 0 deletions
@@ -236,6 +236,11 @@ add_library(dflash27b STATIC
     src/qwen35/qwen35_layer_split_dflash_target.cpp
     src/qwen35/layer_split_daemon_loop.cpp
     src/qwen35/qwen35_daemon.cpp
+    src/qwen36/qwen36_mtp.cpp
+    src/qwen36/qwen36_mtp_graph.cpp
+    src/qwen36/qwen36_mtp_loader.cpp
+    src/common/mtp_chain_runner.cpp
+    src/common/mtp_orchestrator.cpp
     src/common/sampler.cpp
     src/common/daemon_loop.cpp
     src/common/gguf_inspect.cpp
@@ -491,6 +496,11 @@ if(DFLASH27B_TESTS)
         target_include_directories(test_kv_quant PRIVATE ${DFLASH27B_SRC_INCLUDE_DIRS})
         target_link_libraries(test_kv_quant PRIVATE dflash27b)
     endif()
+    if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/test_common_mtp_orchestrator.cpp")
+        add_executable(test_common_mtp_orchestrator test/test_common_mtp_orchestrator.cpp)
+        target_include_directories(test_common_mtp_orchestrator PRIVATE ${DFLASH27B_SRC_INCLUDE_DIRS})
+        target_link_libraries(test_common_mtp_orchestrator PRIVATE dflash27b)
+    endif()
     if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/test_draft_vs_reference.cpp")
         add_executable(test_draft_vs_reference test/test_draft_vs_reference.cpp)
         target_link_libraries(test_draft_vs_reference PRIVATE dflash27b)

dflash/scripts/server.py

Lines changed: 48 additions & 2 deletions
@@ -734,7 +734,11 @@ def build_app(target: Path, draft: Path | None, bin_path: Path, budget: int, max
               verify_mode: str = "ddtree",
               extra_daemon_args: list[str] | None = None,
               lazy_draft: bool = False,
-              verbose_daemon: bool = False) -> FastAPI:
+              verbose_daemon: bool = False,
+              mtp_gguf: Path | None = None,
+              mtp_gamma: int = 3,
+              mtp_draft_source: str = "chain",
+              mtp_draft_topk: int = 1) -> FastAPI:
     import asyncio
     if _extra_daemon_has_target_sharding(extra_daemon_args):
         if prefix_cache_slots > 0 or prefill_cache_slots > 0:
@@ -791,6 +795,19 @@ async def _openai_compat_error_handler(_request: Request, exc: OpenAICompatError
         cmd = [bin_abs, str(target), "--daemon",
                f"--max-ctx={max_ctx}",
                f"--stream-fd={stream_fd_val}"]
+    elif mtp_gguf is not None:
+        # MTP mode: no --draft (MTP head lives inside target or mtp_gguf),
+        # no DFlash flags. Daemon dispatches to MTP code path via --mtp-gguf.
+        cmd = [bin_abs, str(target), "--daemon",
+               f"--max-ctx={max_ctx}",
+               f"--stream-fd={stream_fd_val}",
+               f"--mtp-gguf={mtp_gguf}",
+               f"--gamma={mtp_gamma}",
+               "--draft-source", mtp_draft_source]
+        if mtp_draft_source == "mtp_topk":
+            cmd.append(f"--draft-topk={mtp_draft_topk}")
+        if extra_daemon_args:
+            cmd.extend(extra_daemon_args)
     else:
         if draft is None:
             raise SystemExit("qwen35 arch requires --draft <draft.gguf|model.safetensors>")
@@ -2737,6 +2754,20 @@ def main():
                     help="Pass --draft-feature-mirror to test_dflash (safe cross-GPU feature path)")
     ap.add_argument("--peer-access", action="store_true",
                     help="Pass --peer-access to test_dflash (prefer P2P memcpy when available)")
+    # ── MTP (Multi-Token Prediction) speculator ──────────────────────────────
+    # When --mtp-gguf is set, the daemon runs MTP-head speculation instead of
+    # DFlash+DDTree. --draft is ignored (the MTP head is in the same GGUF as
+    # target, or a separate fused GGUF). Prefix-cache slots are auto-disabled
+    # in MTP mode because RESTORE does not snapshot MTP head KV yet.
+    ap.add_argument("--mtp-gguf", type=Path, default=None,
+                    help="Path to MTP-fused GGUF. When set, daemon runs MTP "
+                         "speculation; --draft and DFlash flags are ignored.")
+    ap.add_argument("--mtp-gamma", type=int, default=3,
+                    help="MTP chain depth (default 3; recommended D=3 per matrix bench)")
+    ap.add_argument("--mtp-draft-source", choices=["chain", "mtp_topk"], default="chain",
+                    help="MTP draft generation strategy (default chain)")
+    ap.add_argument("--mtp-draft-topk", type=int, default=1,
+                    help="Top-K for mtp_topk draft source (default 1, ignored for chain)")
     add_cli_flags(ap)
     args = ap.parse_args()
     prefill_cfg = config_from_args(args)
@@ -2782,6 +2813,17 @@ def main():
         # through the laguna daemon now, so --prefill-compression and
         # --prefix-cache-slots behave the same as on the qwen35 path.
         draft = None
+    elif args.mtp_gguf is not None:
+        # MTP mode: --draft is ignored; MTP head lives in the target (or in --mtp-gguf
+        # if separate). Force prefix/prefill cache off — RESTORE doesn't snapshot
+        # MTP head KV yet (planned for a follow-up PR).
+        if not args.mtp_gguf.is_file():
+            raise SystemExit(f"--mtp-gguf not found at {args.mtp_gguf}")
+        draft = None
+        if args.prefix_cache_slots > 0 or args.prefill_cache_slots > 0:
+            print("  [cfg] MTP mode: disabling prefix/prefill cache (MTP head KV snapshot not implemented)")
+            args.prefix_cache_slots = 0
+            args.prefill_cache_slots = 0
     else:
         draft = resolve_draft(args.draft) if args.draft.is_dir() else args.draft
         if not draft.is_file():
@@ -2813,7 +2855,11 @@ def main():
                     verify_mode=args.verify_mode,
                     extra_daemon_args=placement.daemon_args or None,
                     lazy_draft=args.lazy_draft,
-                    verbose_daemon=args.verbose_daemon)
+                    verbose_daemon=args.verbose_daemon,
+                    mtp_gguf=args.mtp_gguf,
+                    mtp_gamma=args.mtp_gamma,
+                    mtp_draft_source=args.mtp_draft_source,
+                    mtp_draft_topk=args.mtp_draft_topk)
 
     import uvicorn
     logging.basicConfig(

dflash/src/common/attn_masks.h

Lines changed: 13 additions & 2 deletions
@@ -46,13 +46,24 @@ inline void build_causal_mask(std::vector<uint16_t> & out,
 // Build an ancestor-only attention mask for DDTree tree-structured verify.
 // Each query position i can attend to its ancestors in the tree (including
 // itself) plus all past KV positions.
+//
+// kv_pad_override: when nonzero, pin the kv (column) stride to this value
+// instead of the helper's natural `align_up(past_length + N - win_start,
+// kq_stride_pad)`. Needed when the consumer tensor was allocated with a
+// fixed kv extent (e.g. build_target_step_tree sizes sg.attn_mask at
+// align_up(cache.max_ctx + N, kq_stride_pad)) and the helper-computed
+// stride would not match the tensor's actual row pitch. Default 0 keeps
+// existing behavior.
 inline void build_tree_mask(const DDTree & tree, int past_length,
                             std::vector<uint16_t> & out_mask,
                             int kq_stride_pad,
-                            int win_start = 0) {
+                            int win_start = 0,
+                            int kv_pad_override = 0) {
     const int N = 1 + tree.n_nodes;
     const int win_len = past_length + N - win_start;
-    const int kv_pad = align_up(win_len, kq_stride_pad);
+    const int kv_pad = kv_pad_override > 0
+        ? align_up(kv_pad_override, kq_stride_pad)
+        : align_up(win_len, kq_stride_pad);
     const int q_pad = align_up(N, KQ_MASK_PAD);
     out_mask.assign((size_t)kv_pad * q_pad, F16_NEG_INF);
     for (int q = 0; q < N; q++) {
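To make the effect of `kv_pad_override` concrete, the stride selection from the hunk above can be isolated into a small function and evaluated on sample numbers. This is a sketch under assumptions: `align_up` is re-implemented here as "round up to the next multiple", which is presumed to match the repository helper, and `tree_mask_kv_pad` is a hypothetical name that does not exist in the source.

```cpp
#include <cassert>

// Assumed behavior of the repository's align_up: round n up to a multiple of a (a > 0).
inline int align_up(int n, int a) { return ((n + a - 1) / a) * a; }

// Hypothetical helper isolating the kv_pad choice made inside build_tree_mask.
inline int tree_mask_kv_pad(int past_length, int n_nodes, int win_start,
                            int kq_stride_pad, int kv_pad_override) {
    const int N = 1 + n_nodes;                        // root + tree nodes
    const int win_len = past_length + N - win_start;  // natural kv extent
    return kv_pad_override > 0
        ? align_up(kv_pad_override, kq_stride_pad)    // pinned to the tensor's row pitch
        : align_up(win_len, kq_stride_pad);           // previous behavior
}
```

With illustrative values `past_length = 100`, `n_nodes = 7`, and `kq_stride_pad = 32`, the natural stride is 128. If the consumer tensor was sized for `max_ctx + N = 4096 + 8 = 4104` columns, passing that as the override yields 4128, so the host-side mask rows line up with the tensor's rows.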

dflash/src/common/dflash_target.h

Lines changed: 127 additions & 0 deletions
@@ -11,14 +11,39 @@
 
 #pragma once
 
+#include "ddtree.h"
+
 #include <cstdint>
 #include <vector>
 
+struct ggml_backend;
+typedef struct ggml_backend * ggml_backend_t;
+struct ggml_tensor;
+
 namespace dflash27b {
 
 struct DFlashTarget {
     virtual ~DFlashTarget() = default;
 
+    // Return the ggml backend used by this target's graph compute. Default
+    // returns nullptr; callers (e.g. Qwen3.6 MTP) that want to build CUDA
+    // cgraphs against the same backend should check this and fall back if
+    // it's null.
+    virtual ggml_backend_t backend() const { return nullptr; }
+
+    // Optional: return the LM-head weight tensor on the target's backend
+    // (shape [n_embd, n_vocab], used by ggml_mul_mat). When non-null, the
+    // Qwen3.6 MTP step graph fuses `mul_mat(W, x_normed) -> argmax` into
+    // its own cgraph, skipping a hidden -> host -> separate-cgraph round
+    // trip per step. Default returns nullptr so existing targets (CPU
+    // stubs) keep the project_hidden_to_* fallback path.
+    virtual ggml_tensor * lm_head_weight() const { return nullptr; }
+
+    // Optional: causal attention window the target's full-attn blocks use
+    // (kv_len - fa_window). The MTP head uses the same window so it sees
+    // the same active context. 0 means full causal context.
+    virtual int fa_window() const { return 0; }
+
     // ── Target forward ──────────────────────────────────────────────
 
     // Run a batch of tokens through the target model. Returns the argmax
@@ -33,6 +58,26 @@ struct DFlashTarget {
                               int & last_tok,
                               std::vector<int32_t> * all_argmax = nullptr) = 0;
 
+    // Tree-structured verify: run a flat DFS-ordered DDTree through the
+    // target with an ancestor-only attention mask. `flat_tokens[0]` is the
+    // tree root (= last accepted token), `flat_tokens[1..N-1]` are the
+    // DFS-ordered tree nodes (mirroring DDTree::token_ids). `tree.n_nodes`
+    // must equal flat_tokens.size() - 1. On success, `out_argmax` is the
+    // target's argmax at each of the N tree positions (size == N) and
+    // (if non-null) `out_logits` is the raw logits laid out as
+    // [N × vocab] floats. Returns false by default so existing targets
+    // that haven't wired tree-verify can be detected by callers; concrete
+    // targets override to plug in build_target_step_tree + the tree mask.
+    virtual bool verify_tree(const std::vector<int32_t> & flat_tokens,
+                             const DDTree & tree,
+                             int base_pos,
+                             std::vector<int32_t> & out_argmax,
+                             std::vector<float> * out_logits = nullptr) {
+        (void)flat_tokens; (void)tree; (void)base_pos;
+        (void)out_argmax; (void)out_logits;
+        return false;
+    }
+
     // ── KV state management ─────────────────────────────────────────
 
     // Snapshot KV cache state before speculative verify, so it can be
@@ -42,6 +87,32 @@ struct DFlashTarget {
     // Restore KV cache to the last snapshot (undo speculative forward).
     virtual bool restore_kv() = 0;
 
+    // Rollback DeltaNet SSM/conv + full-attn KV to the accepted-path tail of
+    // the most recent verify_tree() call. accepted_dfs[0] must be 0 (root).
+    // Returns false if unsupported; callers must treat false as fatal in
+    // multi-iteration tree-spec loops (poisoned KV/SSM otherwise).
+    virtual bool restore_kv_at_dfs(const std::vector<int> & accepted_dfs) {
+        (void)accepted_dfs;
+        return false;
+    }
+
+    // Roll back DeltaNet SSM/conv + full-attn KV to slot `accept_n` of the
+    // most recent verify_batch chain. Requires chain capture enabled.
+    // Postcondition: cache cur_pos = base_pos + accept_n + 1.
+    // Returns false if unsupported; chain runner falls back to snapshot+recommit.
+    virtual bool restore_kv_at_chain(int accept_n) {
+        (void)accept_n;
+        return false;
+    }
+
+    // Enable per-position DeltaNet intermediate capture in verify_batch.
+    // Off by default; unsafe when n_tokens > max_verify_tokens.
+    virtual void enable_chain_capture(bool /*on*/) {}
+
+    // Record linear-chain topology before verify_batch so restore_kv_at_chain()
+    // can locate the rollback slot. Must be called before each capturable iter.
+    virtual void capture_topology_for_chain(int /*n_tokens*/, int /*base_pos*/) {}
+
     // ── Token utilities ─────────────────────────────────────────────
 
     // Check if a token is end-of-sequence for this model.
@@ -61,6 +132,21 @@ struct DFlashTarget {
                                           int n_tokens,
                                           std::vector<int32_t> & tokens_out) = 0;
 
+    // Optional: project draft hidden states through the target's lm_head
+    // and return the full raw logits (n_tokens * vocab floats) on host.
+    // Used by MTP drafters that need a top-K surface for DDTree (the
+    // argmax path above hides the distribution). Default returns false so
+    // existing targets compile unchanged; concrete targets that wire it
+    // up resize `logits_out` to n_tokens * vocab and return true. The
+    // `out_vocab` param reports the vocab dim back to the caller.
+    virtual bool project_hidden_to_logits(const float * /*hidden*/,
+                                          int /*n_tokens*/,
+                                          std::vector<float> & /*logits_out*/,
+                                          int & out_vocab) {
+        out_vocab = 0;
+        return false;
+    }
+
     // ── Configuration for draft model ───────────────────────────────
 
     // Target's hidden dimension (draft model must match).
@@ -72,6 +158,47 @@ struct DFlashTarget {
     // Which target layers to capture intermediate activations from.
     // The draft model's fc layer expects exactly this many feature slices.
     virtual const std::vector<int> & capture_layer_ids() const = 0;
+
+    // Return the backbone's final post-norm hidden state for the last committed
+    // token (hidden_size() floats, F32). Populated by verify_batch.
+    // Returns nullptr if not yet available (e.g. before first verify_batch).
+    // Default implementation returns nullptr; Qwen35DFlashTarget overrides it.
+    virtual const float * last_hidden() const { return nullptr; }
+
+    // Return the full post-norm hidden sequence from the MOST RECENT
+    // verify_batch call: n_tokens * hidden_size() floats, F32, laid out as
+    // [token_0_hidden, token_1_hidden, ..., token_{n_tokens-1}_hidden].
+    // *out_n_tokens is set to the number of tokens captured (matches the
+    // n_tokens passed to verify_batch). Default returns nullptr.
+    virtual const float * last_hidden_seq(int * out_n_tokens) const {
+        if (out_n_tokens) *out_n_tokens = 0;
+        return nullptr;
+    }
+
+    // Return the post-norm hidden at an ABSOLUTE sequence position, if that
+    // position is covered by the most recent verify_batch's hidden capture.
+    // The Qwen3.6 MTP head needs h_{base_pos-1} for its input pair at each
+    // chain step, which equals last_hidden() only on the first chain step
+    // (right after prefill); subsequent steps need a hidden from earlier in
+    // the most recent verify_batch chunk. Returns nullptr if out of range.
+    virtual const float * hidden_at_pos(int abs_pos) const {
+        (void)abs_pos;
+        return nullptr;
+    }
+
+    // Pre-final-output-norm variant of hidden_at_pos. Mirrors llama.cpp
+    // PR #22673's `t_h_pre_norm`. The Qwen3.6 MTP head's hnorm normalises
+    // h_prev internally; feeding it the post-output-norm tensor double-
+    // normalises and compounds per-depth rejection on D>=2 chains. Spec-
+    // chain callers must prefer this accessor for the outer h_prev_0 seed
+    // and fall back to hidden_at_pos() only if it returns nullptr (e.g.
+    // adapters that do not yet capture the pre-norm sequence). Default
+    // returns nullptr; Qwen35DFlashTarget overrides it when hidden-seq
+    // capture is enabled.
+    virtual const float * hidden_at_pos_pre_norm(int abs_pos) const {
+        (void)abs_pos;
+        return nullptr;
+    }
 };
 
 } // namespace dflash27b
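The comment on `hidden_at_pos_pre_norm` asks spec-chain callers to prefer the pre-norm hidden and fall back to `hidden_at_pos()` only when it returns nullptr. A minimal sketch of that call-site logic follows. `HiddenSource` is a cut-down stand-in for `DFlashTarget` that keeps only the two accessors, and `select_h_prev`, `PostNormOnly`, and `WithPreNorm` are hypothetical names introduced for the example.

```cpp
#include <cassert>

// Stand-in with only the two accessors discussed; not the real DFlashTarget.
struct HiddenSource {
    virtual ~HiddenSource() = default;
    virtual const float * hidden_at_pos(int /*abs_pos*/) const { return nullptr; }
    virtual const float * hidden_at_pos_pre_norm(int /*abs_pos*/) const { return nullptr; }
};

// Hypothetical helper: prefer the pre-norm hidden, fall back to post-norm.
// used_pre_norm (optional) reports which accessor supplied the pointer.
inline const float * select_h_prev(const HiddenSource & t, int abs_pos,
                                   bool * used_pre_norm = nullptr) {
    const float * h = t.hidden_at_pos_pre_norm(abs_pos);
    if (used_pre_norm) *used_pre_norm = (h != nullptr);
    return h ? h : t.hidden_at_pos(abs_pos);
}

// Toy adapter that captures only the post-norm sequence.
struct PostNormOnly : HiddenSource {
    float post[2] = {1.0f, 2.0f};
    const float * hidden_at_pos(int) const override { return post; }
};

// Toy adapter that additionally captures the pre-norm sequence.
struct WithPreNorm : PostNormOnly {
    float pre[2] = {3.0f, 4.0f};
    const float * hidden_at_pos_pre_norm(int) const override { return pre; }
};
```

Because both defaults return nullptr, a target that overrides neither accessor makes the helper return nullptr, which a caller can treat as "hidden capture unavailable" rather than reading stale data.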

dflash/src/common/gguf_metadata.h

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+// common/gguf_metadata.h — Shared helpers for reading GGUF metadata.
+//
+// Provides typed "get or default" accessors and "require" accessors for
+// gguf_context key-value pairs, plus an architecture validation helper.
+// Use these in every loader; do not inline equivalent helpers per-arch.
+//
+// Include convention: #include "common/gguf_metadata.h"
+// Never: ../common/gguf_metadata.h or absolute paths.
+
+#pragma once
+
+#include "gguf.h"
+
+#include <cstdint>
+#include <string>
+
+namespace dflash::common {
+
+// ── Read-with-default ─────────────────────────────────────────────────────
+// Return the stored value when the key is present, default_val otherwise.
+
+inline uint32_t gguf_get_u32_or(struct gguf_context * gguf, const char * key, uint32_t default_val) {
+    int64_t id = gguf_find_key(gguf, key);
+    return (id >= 0) ? gguf_get_val_u32(gguf, id) : default_val;
+}
+
+inline int32_t gguf_get_i32_or(struct gguf_context * gguf, const char * key, int32_t default_val) {
+    int64_t id = gguf_find_key(gguf, key);
+    return (id >= 0) ? gguf_get_val_i32(gguf, id) : default_val;
+}
+
+inline float gguf_get_f32_or(struct gguf_context * gguf, const char * key, float default_val) {
+    int64_t id = gguf_find_key(gguf, key);
+    return (id >= 0) ? gguf_get_val_f32(gguf, id) : default_val;
+}
+
+inline std::string gguf_get_str_or(struct gguf_context * gguf, const char * key, const std::string & default_val) {
+    int64_t id = gguf_find_key(gguf, key);
+    return (id >= 0) ? std::string(gguf_get_val_str(gguf, id)) : default_val;
+}
+
+// ── Required reads ────────────────────────────────────────────────────────
+// Return false and write a descriptive error when the key is absent.
+
+inline bool gguf_require_u32(struct gguf_context * gguf, const char * key,
+                             uint32_t & out, std::string & out_error) {
+    int64_t id = gguf_find_key(gguf, key);
+    if (id < 0) {
+        out_error = std::string("missing required GGUF key: ") + key;
+        return false;
+    }
+    out = gguf_get_val_u32(gguf, id);
+    return true;
+}
+
+inline bool gguf_require_str(struct gguf_context * gguf, const char * key,
+                             std::string & out, std::string & out_error) {
+    int64_t id = gguf_find_key(gguf, key);
+    if (id < 0) {
+        out_error = std::string("missing required GGUF key: ") + key;
+        return false;
+    }
+    out = gguf_get_val_str(gguf, id);
+    return true;
+}
+
+// ── Architecture validation ───────────────────────────────────────────────
+// Return true when "general.architecture" equals expected_arch.
+// On mismatch or absence, writes a descriptive error and returns false.
+
+inline bool gguf_check_architecture(struct gguf_context * gguf,
+                                    const char * expected_arch,
+                                    std::string & out_error) {
+    int64_t id = gguf_find_key(gguf, "general.architecture");
+    if (id < 0) {
+        out_error = "missing required GGUF key: general.architecture";
+        return false;
+    }
+    const char * arch = gguf_get_val_str(gguf, id);
+    if (std::string(arch) != expected_arch) {
+        out_error = std::string("unexpected architecture: got '") + arch
+                  + "', expected '" + expected_arch + "'";
+        return false;
+    }
+    return true;
+}
+
+} // namespace dflash::common
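The header above offers two accessor styles: "get or default" for optional keys and "require" for keys whose absence is an error. The difference can be shown without linking ggml by mirroring the pattern over an in-memory table. This is only an illustration of the calling convention: `FakeGguf`, `get_u32_or`, and `require_u32` are hypothetical stand-ins, and the key names used with them are made up.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>

// Stand-in for gguf_context: a plain key/value table (illustration only).
struct FakeGguf { std::map<std::string, uint32_t> u32; };

// Optional key: return the stored value, or default_val when the key is absent.
inline uint32_t get_u32_or(const FakeGguf & g, const char * key, uint32_t default_val) {
    auto it = g.u32.find(key);
    return it != g.u32.end() ? it->second : default_val;
}

// Required key: return false and fill out_error when the key is absent.
inline bool require_u32(const FakeGguf & g, const char * key,
                        uint32_t & out, std::string & out_error) {
    auto it = g.u32.find(key);
    if (it == g.u32.end()) {
        out_error = std::string("missing required GGUF key: ") + key;
        return false;
    }
    out = it->second;
    return true;
}
```

A loader would typically use the "require" form for values it cannot guess, such as a layer count, and the "or default" form for tuning values that have a sensible fallback. The `out_error` string lets the caller surface one descriptive message instead of aborting inside the helper.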
