Commit fb9264c

ggml-ve graph compiler: make the prompt-eval source size-independent
The N>1 codegen still baked the token count into the generated source in two places, so each distinct prompt length triggered a recompile (~13s NCC) instead of reusing one cached .so:

- the per-op debug comment embedded `n=`/`NB=` (the literal token count);
- the SWIGLU codegen emitted `int nc, nr;` with nr = ne[1]*ne[2]*ne[3] (= N), left over (and (void)'d) after the loop bound moved to elem_n.

Both only made the source text N-dependent; the actual computation already used per-token constants plus the runtime n_tok arg. Dropped them.

Two different prompt lengths now produce a byte-identical source (verified: the second length triggers 0 compiles), so one .so serves any prompt length, matching decode's size-independence. Output is still token-for-token identical to the interpreter.
1 parent 3b176b1 commit fb9264c
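The effect described in the commit message can be illustrated with a small standalone sketch. It is not the code in graph_compiler.cpp: the function names `gen_baked`, `gen_size_independent`, `cache_key`, and `check_size_independence` are hypothetical, and `std::hash` merely stands in for whatever source hash the JIT cache actually uses.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <sstream>
#include <string>

// Hypothetical "before" generator: the token count leaks into the emitted
// text through the debug comment and a baked loop bound.
std::string gen_baked(std::int64_t per_token, std::int64_t n_tok) {
    std::ostringstream ss;
    ss << "// op 0: SWIGLU n=" << n_tok << " pt=" << per_token << "\n";
    ss << "long total = " << (per_token * n_tok) << "LL;\n";
    return ss.str();
}

// Hypothetical "after" generator: only per-token constants are emitted and
// the token count is left to a runtime argument named n_tok.
std::string gen_size_independent(std::int64_t per_token, std::int64_t n_tok) {
    (void) n_tok;  // deliberately unused: must not appear in the source text
    std::ostringstream ss;
    ss << "// op 0: SWIGLU pt=" << per_token << "\n";
    ss << "long total = (long)(" << per_token << "LL * n_tok);\n";
    return ss.str();
}

// Assumed cache key: a hash of the generated source text.
std::size_t cache_key(const std::string & src) {
    return std::hash<std::string>{}(src);
}

// Two prompt lengths give different baked sources, but one shared
// size-independent source (and therefore one cache key).
void check_size_independence() {
    assert(gen_baked(4096, 5) != gen_baked(4096, 7));
    assert(gen_size_independent(4096, 5) == gen_size_independent(4096, 7));
    assert(cache_key(gen_size_independent(4096, 5)) ==
           cache_key(gen_size_independent(4096, 7)));
}
```

Calling `check_size_independence()` exercises both generators. Any cache keyed on the source text behaves the same way, whatever hash function it uses.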

1 file changed

Lines changed: 13 additions & 15 deletions

ggml/src/ggml-ve/graph_compiler.cpp

@@ -726,8 +726,12 @@ std::string GraphCompiler::gen_op_code(const TracedOp & op, int idx) const {
  const std::string elem_n = scales_n ? (std::to_string(pt) + "LL * n_tok")
                                      : (std::to_string(full) + "LL");

+ // NOTE: keep this comment free of N-dependent values (n, NB) — the JIT
+ // cache key is the source hash, so baking the token count here would force
+ // a separate compile per prompt length. pt is per-token (N-independent).
+ (void) n;
  ss << " // op " << idx << ": " << op_type_name(op.type)
-    << " '" << op.name << "' n=" << n << " pt=" << pt << " NB=" << NB << "\n";
+    << " '" << op.name << "' pt=" << pt << "\n";

  switch (op.type) {
  case OpType::GET_ROWS: {
@@ -1209,25 +1213,19 @@ std::string GraphCompiler::gen_op_code(const TracedOp & op, int idx) const {
  }

  case OpType::GLU_SWIGLU: {
- // SWIGLU: y = silu(gate) * up. gate/up/dst are F32 with the
- // same shape; row count is ggml_nrows(gate) = ne[1]*ne[2]*ne[3].
- int64_t nc = op.ne[0];
- int64_t nr = (op.ne[1] > 0 ? op.ne[1] : 1) *
-              (op.ne[2] > 0 ? op.ne[2] : 1) *
-              (op.ne[3] > 0 ? op.ne[3] : 1);
+ // SWIGLU: y = silu(gate) * up. gate/up/dst are F32 with the same
+ // shape. We parallelise the flat element range (= per-token ffn dim
+ // * n_tok) with one `#pragma omp for` (all 8 threads share it; the
+ // external swiglu_hbm_full_inner split over rows, leaving 7 threads
+ // idle at decode where there's one row). NOTE: emit no N-dependent
+ // size here — elem_n is per-token*n_tok, keeping the source
+ // size-independent. Clamp before expf — the VE's vectorised expf
+ // returns NaN past |x|~88 (see CLAUDE.md).
  ss << " {\n";
- ss << " int nc = " << nc << ", nr = " << nr << ";\n";
  ss << " float* y = (float*)" << dst << ";\n";
  ss << " float* gate = (float*)" << src0 << ";\n";
  ss << " float* up = (float*)" << src1 << ";\n";
- // Inline SWIGLU over the flat element range. gate/up/y are
- // contiguous row-major [nc,nr], so we parallelise the whole
- // nc*nr range with one `#pragma omp for` (all 8 threads share it;
- // the external swiglu_hbm_full_inner only split over nr, leaving
- // 7 threads idle at decode where nr==1). Clamp before expf — the
- // VE's vectorised expf returns NaN past |x|~88 (see CLAUDE.md).
  ss << " long total = (long)(" << elem_n << ");\n";
- ss << " (void)nc; (void)nr;\n";
  ss << " #pragma omp for\n";
  ss << " for (long i = 0; i < total; i++) {\n";
  ss << " float g = gate[i];\n";
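The hunk is cut off after the first statement of the loop body, so the rest of the emitted kernel is not visible here. As a rough guide to what the comment describes, the following is a minimal standalone sketch of a flat-range SWIGLU with a clamp before the exponential. It is an assumption-laden illustration, not the generated code: the name `swiglu_flat`, the constant `kExpClamp` (set to 88 from the "|x|~88" remark), and the exact clamp placement are all hypothetical. It uses `#pragma omp parallel for` because it is a free-standing function, whereas the generated code uses `#pragma omp for`, presumably inside an existing parallel region.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Hypothetical clamp bound; the diff only says expf misbehaves past |x|~88.
constexpr float kExpClamp = 88.0f;

// y[i] = silu(gate[i]) * up[i] over the flat range per_token * n_tok.
// per_token plays the role of the compile-time constant in elem_n and
// n_tok is the runtime argument, so nothing here depends on a baked N.
void swiglu_flat(const float * gate, const float * up, float * y,
                 long per_token, long n_tok) {
    assert(per_token >= 0 && n_tok >= 0);
    const long total = per_token * n_tok;
    // Ignored when compiled without OpenMP; the loop then runs serially.
    #pragma omp parallel for
    for (long i = 0; i < total; i++) {
        const float g = gate[i];
        // Clamp the exponent argument so exp never sees |x| > kExpClamp.
        const float x = std::min(std::max(-g, -kExpClamp), kExpClamp);
        // silu(g) = g / (1 + exp(-g))
        y[i] = (g / (1.0f + std::exp(x))) * up[i];
    }
}
```

Because the loop runs over the flattened element range rather than over rows, a single-row decode step still spreads work across all threads, which is the motivation given in the diff comment.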
