11#include " models.h"
22
3+ // get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
4+ static ggml_tensor * ggml_view_2d_slice (ggml_context * ctx0, ggml_tensor * x, int idx) {
5+ GGML_ASSERT (idx < (int ) x->ne [2 ]);
6+ return ggml_view_2d (ctx0, x, x->ne [0 ], x->ne [1 ], ggml_row_size (x->type , x->ne [0 ]),
7+ idx * x->ne [0 ] * x->ne [1 ] * ggml_element_size (x));
8+ }
9+
310llm_build_gemma3n_iswa::llm_build_gemma3n_iswa (const llama_model & model, const llm_graph_params & params) :
411 llm_graph_context(params),
512 model(model),
@@ -22,8 +29,11 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
2229 // TODO: is causal == true correct? might need some changes
2330 auto * inp_attn = build_attn_inp_kv_iswa ();
2431
25- // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
26- ggml_tensor * inp_per_layer = project_per_layer_inputs (inpL, get_per_layer_inputs ());
32+ ggml_tensor * inp_per_layer = build_inp_per_layer ();
33+ ggml_build_forward_expand (gf, inp_per_layer);
34+
35+ // after projection, inp_per_layer has shape: [n_embd_altup, n_tokens, n_layer]
36+ inp_per_layer = project_per_layer_inputs (inpL, inp_per_layer);
2737
2838 // inpL now has only 1 altup, project it to the rest of the altups
2939 // these "added" altups will be concat to the last dim of inpL
@@ -37,8 +47,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
3747 inpL = ggml_concat (ctx0, inpL, altup_added, 2 ); // shape: [n_embd, n_tokens, n_altup]
3848 cb (inpL, " inp_stacked" , -1 );
3949 }
40- // inpL now has shape: [n_embd, n_tokens, n_altup]
41- // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]
50+ // inpL now has shape: [n_embd, n_tokens, n_altup]
4251
4352 for (int il = 0 ; il < n_layer; ++il) {
4453 // this block is made to closely resemble Gemma3p5DecoderLayer in the python code
@@ -49,8 +58,8 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
4958 ggml_tensor * predictions = altup_predict (cur, il); // [n_embd, n_tokens, n_altup]
5059
5160 // predicted value will go through self-attention and laurel
52- ggml_tensor * active_prediction = view_2d_slice ( predictions, i_altup_act); // [n_embd, n_tokens]
53- cur = active_prediction;
61+ ggml_tensor * active_prediction = ggml_view_2d_slice (ctx0, predictions, i_altup_act); // [n_embd, n_tokens]
62+ cur = active_prediction;
5463 cb (cur, " active_prediction" , il);
5564
5665 // norm
@@ -151,12 +160,13 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
151160
152161 ggml_tensor * first_prediction; // [n_embd, n_tokens]
153162 {
154- first_prediction = view_2d_slice ( corrected, i_altup_act); // [n_embd, n_tokens]
163+ first_prediction = ggml_view_2d_slice (ctx0, corrected, i_altup_act); // [n_embd, n_tokens]
155164 first_prediction = ggml_mul (ctx0, first_prediction, model.layers [il].altup_correct_scale );
156165 first_prediction = build_lora_mm (model.layers [il].per_layer_inp_gate , first_prediction);
157166 first_prediction = ggml_gelu (ctx0, first_prediction); // [n_embd_altup, n_tokens]
158167 cb (first_prediction, " first_prediction_gated" , il);
159- ggml_tensor * inp_this_layer = view_2d_slice (inp_per_layer, il); // [n_embd_altup, n_tokens]
168+
169+ ggml_tensor * inp_this_layer = ggml_view_2d_slice (ctx0, inp_per_layer, il); // [n_embd_altup, n_tokens]
160170 first_prediction = ggml_mul (ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens]
161171 cb (first_prediction, " first_prediction_scaled" , il);
162172
@@ -167,7 +177,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
167177 }
168178 // equivalent to python code: corrected_predictions[1:] += first_prediction
169179 {
170- ggml_tensor * slice_first = view_2d_slice ( corrected, 0 );
180+ ggml_tensor * slice_first = ggml_view_2d_slice (ctx0, corrected, 0 );
171181 ggml_tensor * slice_rest = ggml_view_3d (
172182 ctx0, corrected, n_embd, n_tokens, n_altup - 1 , ggml_row_size (corrected->type , n_embd),
173183 ggml_row_size (corrected->type , n_embd * n_tokens), n_embd * n_tokens * ggml_element_size (corrected));
@@ -185,7 +195,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
185195
186196 // cur now has multiple altup(s), we want to merge them back to 1 altup
187197 {
188- ggml_tensor * target_magnitude = calc_magnitude (view_2d_slice ( cur, i_altup_act)); // [n_embd, n_tokens]
198+ ggml_tensor * target_magnitude = calc_magnitude (ggml_view_2d_slice (ctx0, cur, i_altup_act)); // [n_embd, n_tokens]
189199 // do a view to skip the first slice (active altup)
190200 ggml_tensor * alt_slice =
191201 ggml_view_3d (ctx0, cur, n_embd, n_tokens, n_altup - 1 , ggml_row_size (cur->type , n_embd),
@@ -197,9 +207,9 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
197207 cb (altup_unembd, " altup_unembd" , -1 );
198208
199209 // equivalent to torch.mean(hidden_states, dim=0)
200- cur = view_2d_slice ( cur, 0 ); // [n_embd, n_tokens]
210+ cur = ggml_view_2d_slice (ctx0, cur, 0 ); // [n_embd, n_tokens]
201211 for (int i = 0 ; i < n_altup - 1 ; ++i) {
202- cur = ggml_add (ctx0, cur, view_2d_slice ( altup_unembd, i));
212+ cur = ggml_add (ctx0, cur, ggml_view_2d_slice (ctx0, altup_unembd, i));
203213 }
204214 cur = ggml_scale (ctx0, cur, 1 .0f / float (n_altup)); // [n_embd, n_tokens]
205215 cb (cur, " unembd_merged" , -1 );
@@ -235,34 +245,27 @@ ggml_tensor * llm_build_gemma3n_iswa::calc_magnitude(ggml_tensor * x) {
235245 return ggml_sqrt (ctx0, ggml_sum_rows (ctx0, ggml_sqr (ctx0, x)));
236246}
237247
238- // get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
239- ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice (ggml_tensor * x, int idx) {
240- GGML_ASSERT (idx < (int ) x->ne [2 ]);
241- return ggml_view_2d (ctx0, x, x->ne [0 ], x->ne [1 ], ggml_row_size (x->type , x->ne [0 ]),
242- idx * x->ne [0 ] * x->ne [1 ] * ggml_element_size (x));
243- }
244-
245248// equivalent to get_per_layer_inputs() in python code
246249// output shape: [n_embd_altup, n_layer, n_tokens]
247- ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs () {
250+ ggml_tensor * llm_build_gemma3n_iswa::build_inp_per_layer () {
248251 auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
249252 ggml_tensor * inp_per_layer;
250253 if (ubatch.token ) {
251254 inp->tokens = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, ubatch.n_tokens );
252255 ggml_set_input (inp->tokens );
253256 res->t_inp_tokens = inp->tokens ;
254- inp_per_layer = ggml_get_rows (ctx0, model.tok_embd_per_layer , inp->tokens );
257+ inp_per_layer = ggml_get_rows (ctx0, model.per_layer_tok_embd , inp->tokens );
255258 inp_per_layer = ggml_reshape_3d (ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
256259 inp_per_layer = ggml_scale (ctx0, inp_per_layer, sqrtf ((float ) n_embd_altup));
257260 cb (inp_per_layer, " inp_per_layer_selected" , -1 );
258261 res->add_input (std::move (inp));
259262 } else {
260263 // Vision embedding path: use padding token (ID=0) embedding
261264 // TODO: verify if this is the correct behavior in transformers implementation
262- const int64_t embd_size = model.tok_embd_per_layer ->ne [0 ]; // n_embd_altup * n_layer
265+ const int64_t embd_size = model.per_layer_tok_embd ->ne [0 ]; // n_embd_altup * n_layer
263266
264267 // Extract and dequantize padding token embedding (row 0)
265- ggml_tensor * padding = ggml_view_1d (ctx0, model.tok_embd_per_layer , embd_size, 0 );
268+ ggml_tensor * padding = ggml_view_1d (ctx0, model.per_layer_tok_embd , embd_size, 0 );
266269 inp_per_layer = ggml_cast (ctx0, padding, GGML_TYPE_F32);
267270
268271 // Reshape to [n_embd_altup, n_layer, 1]
@@ -275,18 +278,19 @@ ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
275278// equivalent to project_per_layer_inputs() in python code
276279// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
277280// output shape: [n_embd_altup, n_tokens, n_layer]
278- ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs (ggml_tensor * inputs_embeds , ggml_tensor * inp_per_layer) {
281+ ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs (ggml_tensor * inp_batch , ggml_tensor * inp_per_layer) {
279282 const float per_layer_projection_scale = 1 .0f / sqrtf ((float ) n_embd);
280283 const float per_layer_input_scale = 1 .0f / sqrtf (2 .0f );
281284
282- ggml_tensor * per_layer_proj = ggml_mul_mat (ctx0, model.per_layer_model_proj , inputs_embeds);
283- per_layer_proj = ggml_scale (ctx0, per_layer_proj, per_layer_projection_scale);
284- per_layer_proj = ggml_reshape_3d (ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
285- per_layer_proj = build_norm (per_layer_proj, model.per_layer_proj_norm , NULL , LLM_NORM_RMS,
286- -1 ); // [n_embd_altup, n_layer, n_tokens]
285+ ggml_tensor * per_layer_proj;
286+ per_layer_proj = ggml_mul_mat (ctx0, model.per_layer_model_proj , inp_batch);
287+ per_layer_proj = ggml_scale (ctx0, per_layer_proj, per_layer_projection_scale);
288+ per_layer_proj = ggml_reshape_3d (ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
289+
290+ per_layer_proj = build_norm (per_layer_proj, model.per_layer_proj_norm , NULL , LLM_NORM_RMS, -1 );
287291 cb (per_layer_proj, " per_layer_proj" , -1 );
288292
289- inp_per_layer = ggml_add (ctx0, per_layer_proj, inp_per_layer);
293+ inp_per_layer = ggml_add (ctx0, per_layer_proj, inp_per_layer);
290294 inp_per_layer = ggml_scale (ctx0, inp_per_layer, per_layer_input_scale);
291295 cb (inp_per_layer, " inp_per_layer" , -1 );
292296
@@ -337,7 +341,7 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tenso
337341// input cur shape: [n_embd, n_tokens, n_altup]
338342// output shape: [n_embd, n_tokens, n_altup]
339343ggml_tensor * llm_build_gemma3n_iswa::altup_predict (ggml_tensor * cur, int il) {
340- ggml_tensor * activated = view_2d_slice ( cur, i_altup_act); // [n_embd, n_tokens]
344+ ggml_tensor * activated = ggml_view_2d_slice (ctx0, cur, i_altup_act); // [n_embd, n_tokens]
341345 ggml_tensor * modalities = altup_compute_router_modalities (activated, il); // [n_altup, n_tokens]
342346 cb (modalities, " modalities" , il);
343347
@@ -365,7 +369,7 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_correct(ggml_tensor * predictions, g
365369 ggml_tensor * modalities = altup_compute_router_modalities (activated, il); // [n_altup, n_tokens]
366370 cb (modalities, " modalities" , il);
367371
368- ggml_tensor * active_prediction = view_2d_slice ( predictions, i_altup_act);
372+ ggml_tensor * active_prediction = ggml_view_2d_slice (ctx0, predictions, i_altup_act);
369373 ggml_tensor * innovation = ggml_sub (ctx0, activated, active_prediction); // [n_embd, n_tokens]
370374 cb (innovation, " innovation" , il);
371375
0 commit comments