Skip to content

Commit 8fc8b82

Browse files
authored
gemma : perform per-layer projections in the first layer (ggml-org#21612)
* gemma : reduce graph splits by keeping per-layer ops in the input layer
* gemma : put the per-layer proj in the first layer
* cont : move the projection before the layer loop
1 parent cda88c3 commit 8fc8b82

6 files changed

Lines changed: 108 additions & 91 deletions

File tree

src/llama-arch.cpp

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -558,20 +558,20 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
558558
// example: https://github.com/ggml-org/llama.cpp/pull/17548
559559
//
560560
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
561-
{LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
562-
{LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
563-
{LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
561+
{LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
562+
{LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
563+
{LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
564564
{LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // do the norms on the first layer (not the input layer)
565-
{LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
566-
{LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
567-
{LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
568-
{LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
569-
{LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
570-
{LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
571-
{LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
572-
{LLM_TENSOR_OUTPUT_NORM_LFM2, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
573-
{LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
574-
{LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
565+
{LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
566+
{LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
567+
{LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
568+
{LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
569+
{LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
570+
{LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
571+
{LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
572+
{LLM_TENSOR_OUTPUT_NORM_LFM2, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
573+
{LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
574+
{LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
575575
{LLM_TENSOR_ROPE_FREQS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
576576
{LLM_TENSOR_ROPE_FACTORS_LONG, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
577577
{LLM_TENSOR_ROPE_FACTORS_SHORT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
@@ -708,9 +708,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
708708
{LLM_TENSOR_FFN_UP_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
709709
{LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
710710
// altup / laurel (gemma 3n)
711-
{LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
712-
{LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
713-
{LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
711+
{LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
712+
{LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
713+
{LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
714714
{LLM_TENSOR_ALTUP_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
715715
{LLM_TENSOR_ALTUP_UNEMBD_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
716716
{LLM_TENSOR_PER_LAYER_INP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},

src/llama-model.cpp

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -4211,13 +4211,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
42114211
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
42124212
}
42134213

4214-
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
4215-
tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0);
4214+
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
4215+
4216+
altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
4217+
altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
42164218

4217-
altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
4218-
altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
4219-
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_altup * n_layer}, 0);
4220-
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_altup}, 0);
4219+
per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0);
4220+
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_altup * n_layer}, 0);
4221+
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_altup}, 0);
42214222

42224223
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
42234224

@@ -4276,9 +4277,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
42764277
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
42774278

42784279
if (n_embd_per_layer > 0) {
4279-
tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_per_layer * n_layer, n_vocab}, 0);
4280-
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_per_layer * n_layer}, 0);
4281-
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_per_layer}, 0);
4280+
per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_per_layer * n_layer, n_vocab}, 0);
4281+
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_per_layer * n_layer}, 0);
4282+
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_per_layer}, 0);
42824283
}
42834284

42844285
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);

src/llama-model.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -534,9 +534,9 @@ struct llama_model {
534534
struct ggml_tensor * conv1d_b = nullptr;
535535

536536
// gemma3n altup
537-
struct ggml_tensor * tok_embd_per_layer = nullptr;
538537
struct ggml_tensor * altup_proj = nullptr;
539538
struct ggml_tensor * altup_unembd_proj = nullptr;
539+
struct ggml_tensor * per_layer_tok_embd = nullptr;
540540
struct ggml_tensor * per_layer_model_proj = nullptr;
541541
struct ggml_tensor * per_layer_proj_norm = nullptr;
542542

src/models/gemma3n-iswa.cpp

Lines changed: 36 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,12 @@
11
#include "models.h"
22

3+
// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
4+
static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) {
5+
GGML_ASSERT(idx < (int) x->ne[2]);
6+
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
7+
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
8+
}
9+
310
llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) :
411
llm_graph_context(params),
512
model(model),
@@ -22,8 +29,11 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
2229
// TODO: is causal == true correct? might need some changes
2330
auto * inp_attn = build_attn_inp_kv_iswa();
2431

25-
// inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
26-
ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
32+
ggml_tensor * inp_per_layer = build_inp_per_layer();
33+
ggml_build_forward_expand(gf, inp_per_layer);
34+
35+
// inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]
36+
inp_per_layer = project_per_layer_inputs(inpL, inp_per_layer);
2737

2838
// inpL now has only 1 altup, project it to the rest of the altups
2939
// these "added" altups will be concat to the last dim of inpL
@@ -37,8 +47,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
3747
inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup]
3848
cb(inpL, "inp_stacked", -1);
3949
}
40-
// inpL now has shape: [n_embd, n_tokens, n_altup]
41-
// inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]
50+
// inpL now has shape: [n_embd, n_tokens, n_altup]
4251

4352
for (int il = 0; il < n_layer; ++il) {
4453
// this block is made to be closely resemble Gemma3p5DecoderLayer on python code
@@ -49,8 +58,8 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
4958
ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup]
5059

5160
// predicted value will go through self-attention and laurel
52-
ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens]
53-
cur = active_prediction;
61+
ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act); // [n_embd, n_tokens]
62+
cur = active_prediction;
5463
cb(cur, "active_prediction", il);
5564

5665
// norm
@@ -151,12 +160,13 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
151160

152161
ggml_tensor * first_prediction; // [n_embd, n_tokens]
153162
{
154-
first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens]
163+
first_prediction = ggml_view_2d_slice(ctx0, corrected, i_altup_act); // [n_embd, n_tokens]
155164
first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
156165
first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
157166
first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens]
158167
cb(first_prediction, "first_prediction_gated", il);
159-
ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens]
168+
169+
ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_altup, n_tokens]
160170
first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens]
161171
cb(first_prediction, "first_prediction_scaled", il);
162172

@@ -167,7 +177,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
167177
}
168178
// equivalent to python code: corrected_predictions[1:] += first_prediction
169179
{
170-
ggml_tensor * slice_first = view_2d_slice(corrected, 0);
180+
ggml_tensor * slice_first = ggml_view_2d_slice(ctx0, corrected, 0);
171181
ggml_tensor * slice_rest = ggml_view_3d(
172182
ctx0, corrected, n_embd, n_tokens, n_altup - 1, ggml_row_size(corrected->type, n_embd),
173183
ggml_row_size(corrected->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(corrected));
@@ -185,7 +195,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
185195

186196
// cur now has multiple altup(s), we want to merge them back to 1 altup
187197
{
188-
ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens]
198+
ggml_tensor * target_magnitude = calc_magnitude(ggml_view_2d_slice(ctx0, cur, i_altup_act)); // [n_embd, n_tokens]
189199
// do a view to skip the first slice (active altup)
190200
ggml_tensor * alt_slice =
191201
ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1, ggml_row_size(cur->type, n_embd),
@@ -197,9 +207,9 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
197207
cb(altup_unembd, "altup_unembd", -1);
198208

199209
// equivalent to torch.mean(hidden_states, dim=0)
200-
cur = view_2d_slice(cur, 0); // [n_embd, n_tokens]
210+
cur = ggml_view_2d_slice(ctx0, cur, 0); // [n_embd, n_tokens]
201211
for (int i = 0; i < n_altup - 1; ++i) {
202-
cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i));
212+
cur = ggml_add(ctx0, cur, ggml_view_2d_slice(ctx0, altup_unembd, i));
203213
}
204214
cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens]
205215
cb(cur, "unembd_merged", -1);
@@ -235,34 +245,27 @@ ggml_tensor * llm_build_gemma3n_iswa::calc_magnitude(ggml_tensor * x) {
235245
return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x)));
236246
}
237247

238-
// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
239-
ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) {
240-
GGML_ASSERT(idx < (int) x->ne[2]);
241-
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
242-
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
243-
}
244-
245248
// equivalent to get_per_layer_inputs() in python code
246249
// output shape: [n_embd_altup, n_layer, n_tokens]
247-
ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
250+
ggml_tensor * llm_build_gemma3n_iswa::build_inp_per_layer() {
248251
auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
249252
ggml_tensor * inp_per_layer;
250253
if (ubatch.token) {
251254
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
252255
ggml_set_input(inp->tokens);
253256
res->t_inp_tokens = inp->tokens;
254-
inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
257+
inp_per_layer = ggml_get_rows(ctx0, model.per_layer_tok_embd, inp->tokens);
255258
inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
256259
inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup));
257260
cb(inp_per_layer, "inp_per_layer_selected", -1);
258261
res->add_input(std::move(inp));
259262
} else {
260263
// Vision embedding path: use padding token (ID=0) embedding
261264
// TODO: verify if this is the correct behavior in transformers implementation
262-
const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_altup * n_layer
265+
const int64_t embd_size = model.per_layer_tok_embd->ne[0]; // n_embd_altup * n_layer
263266

264267
// Extract and dequantize padding token embedding (row 0)
265-
ggml_tensor * padding = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0);
268+
ggml_tensor * padding = ggml_view_1d(ctx0, model.per_layer_tok_embd, embd_size, 0);
266269
inp_per_layer = ggml_cast(ctx0, padding, GGML_TYPE_F32);
267270

268271
// Reshape to [n_embd_altup, n_layer, 1]
@@ -275,18 +278,19 @@ ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
275278
// equivalent to project_per_layer_inputs() in python code
276279
// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
277280
// output shape: [n_embd_altup, n_tokens, n_layer]
278-
ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
281+
ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) {
279282
const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd);
280283
const float per_layer_input_scale = 1.0f / sqrtf(2.0f);
281284

282-
ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
283-
per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
284-
per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
285-
per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS,
286-
-1); // [n_embd_altup, n_layer, n_tokens]
285+
ggml_tensor * per_layer_proj;
286+
per_layer_proj = ggml_mul_mat (ctx0, model.per_layer_model_proj, inp_batch);
287+
per_layer_proj = ggml_scale (ctx0, per_layer_proj, per_layer_projection_scale);
288+
per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
289+
290+
per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS, -1);
287291
cb(per_layer_proj, "per_layer_proj", -1);
288292

289-
inp_per_layer = ggml_add(ctx0, per_layer_proj, inp_per_layer);
293+
inp_per_layer = ggml_add (ctx0, per_layer_proj, inp_per_layer);
290294
inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
291295
cb(inp_per_layer, "inp_per_layer", -1);
292296

@@ -337,7 +341,7 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tenso
337341
// input cur shape: [n_embd, n_tokens, n_altup]
338342
// output shape: [n_embd, n_tokens, n_altup]
339343
ggml_tensor * llm_build_gemma3n_iswa::altup_predict(ggml_tensor * cur, int il) {
340-
ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens]
344+
ggml_tensor * activated = ggml_view_2d_slice(ctx0, cur, i_altup_act); // [n_embd, n_tokens]
341345
ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
342346
cb(modalities, "modalities", il);
343347

@@ -365,7 +369,7 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_correct(ggml_tensor * predictions, g
365369
ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
366370
cb(modalities, "modalities", il);
367371

368-
ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);
372+
ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act);
369373
ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens]
370374
cb(innovation, "innovation", il);
371375

0 commit comments

Comments
 (0)