@@ -880,17 +880,42 @@ static tq_model_t* tq_load_safetensors(const char* path) {
880880 model -> config .intermediate_dim = model -> config .hidden_dim * 4 ;
881881 }
882882
883+ /* Detect Gemma3 architecture by presence of pre_feedforward_layernorm */
884+ {
885+ snprintf (name_buf , sizeof (name_buf ),
886+ "model.layers.0.pre_feedforward_layernorm.weight" );
887+ tensor_info_t * gemma3_probe = find_tensor (tensors , n_tensors , name_buf );
888+ if (gemma3_probe ) {
889+ model -> config .model_type = 1 ; /* gemma3 */
890+ model -> config .n_norms_per_block = 4 ;
891+ fprintf (stderr , "tq_load_model: detected Gemma3 architecture (4 norms per block)\n" );
892+ } else {
893+ model -> config .model_type = 0 ; /* qwen35 */
894+ model -> config .n_norms_per_block = 2 ;
895+ }
896+ }
897+
883898 /* Defaults — Gemma3 values if detected, Qwen3.5 values if DeltaNet detected, fallback values otherwise */
884899 model -> config .max_seq_len = 4096 ;
885- if (model -> config .delta_n_heads > 0 ) {
900+ if (model -> config .model_type == 1 ) {
901+ /* Gemma3: rope_theta=1M for global, 10K for local, rms_norm_eps=1e-6 */
902+ model -> config .rope_freq_base = 1000000.0f ; /* global layers */
903+ model -> config .rope_local_base_freq = 10000.0f ; /* sliding/local layers */
904+ model -> config .rms_norm_eps = 1e-6f ;
905+ model -> config .partial_rotary_factor = 0.0f ;
906+ model -> config .sliding_window = 512 ;
907+ model -> config .query_pre_attn_scalar = 256.0f ;
908+ } else if (model -> config .delta_n_heads > 0 ) {
886909 /* Qwen3.5 uses rope_theta=10M, rms_norm_eps=1e-6, partial_rotary=0.25 */
887910 model -> config .rope_freq_base = 10000000.0f ;
888911 model -> config .rms_norm_eps = 1e-6f ;
889912 model -> config .partial_rotary_factor = 0.25f ;
913+ model -> config .query_pre_attn_scalar = 0.0f ;
890914 } else {
891915 model -> config .rope_freq_base = 10000.0f ;
892916 model -> config .rms_norm_eps = 1e-5f ;
893917 model -> config .partial_rotary_factor = 0.0f ;
918+ model -> config .query_pre_attn_scalar = 0.0f ;
894919 }
895920
896921 /* Allocate layer weight pointers */
@@ -917,13 +942,32 @@ static tq_model_t* tq_load_safetensors(const char* path) {
917942 find_tensor (tensors , n_tensors , name_buf ),
918943 & conv_buf , & conv_used , conv_capacity );
919944
920- /* FFN norm */
945+ /* FFN norm (Qwen3.5: post_attention_layernorm used as pre-FFN norm) */
921946 snprintf (name_buf , sizeof (name_buf ),
922947 "model.layers.%d.post_attention_layernorm.weight" , l );
923948 layer -> ffn_norm = load_tensor (data_base ,
924949 find_tensor (tensors , n_tensors , name_buf ),
925950 & conv_buf , & conv_used , conv_capacity );
926951
952+ /* Gemma3 extra norms: post_attn, pre_ffn, post_ffn */
953+ if (model -> config .model_type == 1 ) {
954+ /* For Gemma3, post_attention_layernorm is applied to attn output,
955+ * not as pre-FFN norm. Store it in post_attn_norm. */
956+ layer -> post_attn_norm = layer -> ffn_norm ;
957+
958+ snprintf (name_buf , sizeof (name_buf ),
959+ "model.layers.%d.pre_feedforward_layernorm.weight" , l );
960+ layer -> pre_ffn_norm = load_tensor (data_base ,
961+ find_tensor (tensors , n_tensors , name_buf ),
962+ & conv_buf , & conv_used , conv_capacity );
963+
964+ snprintf (name_buf , sizeof (name_buf ),
965+ "model.layers.%d.post_feedforward_layernorm.weight" , l );
966+ layer -> post_ffn_norm = load_tensor (data_base ,
967+ find_tensor (tensors , n_tensors , name_buf ),
968+ & conv_buf , & conv_used , conv_capacity );
969+ }
970+
927971 /* Q, K, V, O projections — only exist for self_attn layers */
928972 snprintf (name_buf , sizeof (name_buf ),
929973 "model.layers.%d.self_attn.q_proj.weight" , l );
@@ -1107,6 +1151,77 @@ static tq_model_t* tq_load_safetensors(const char* path) {
11071151 fprintf (stderr , "tq_load_model: applied Qwen3.5 RMSNorm +1 weight adjustment\n" );
11081152 }
11091153
1154+ /* Gemma3 RMSNorm adjustment: same (1+w) scaling as Qwen3.5 */
1155+ if (model -> config .model_type == 1 ) {
1156+ int dim_h = model -> config .hidden_dim ;
1157+ int head_dim_h = model -> config .head_dim ;
1158+
1159+ for (int l = 0 ; l < n_layers ; l ++ ) {
1160+ tq_layer_weights_t * layer_w = & model -> layers [l ];
1161+ if (layer_w -> attn_norm ) {
1162+ for (int i = 0 ; i < dim_h ; i ++ ) {
1163+ layer_w -> attn_norm [i ] += 1.0f ;
1164+ }
1165+ }
1166+ if (layer_w -> post_attn_norm ) {
1167+ for (int i = 0 ; i < dim_h ; i ++ ) {
1168+ layer_w -> post_attn_norm [i ] += 1.0f ;
1169+ }
1170+ }
1171+ if (layer_w -> pre_ffn_norm ) {
1172+ for (int i = 0 ; i < dim_h ; i ++ ) {
1173+ layer_w -> pre_ffn_norm [i ] += 1.0f ;
1174+ }
1175+ }
1176+ if (layer_w -> post_ffn_norm ) {
1177+ for (int i = 0 ; i < dim_h ; i ++ ) {
1178+ layer_w -> post_ffn_norm [i ] += 1.0f ;
1179+ }
1180+ }
1181+ if (layer_w -> q_norm ) {
1182+ for (int i = 0 ; i < head_dim_h ; i ++ ) {
1183+ layer_w -> q_norm [i ] += 1.0f ;
1184+ }
1185+ }
1186+ if (layer_w -> k_norm ) {
1187+ for (int i = 0 ; i < head_dim_h ; i ++ ) {
1188+ layer_w -> k_norm [i ] += 1.0f ;
1189+ }
1190+ }
1191+ }
1192+ if (model -> output_norm ) {
1193+ for (int i = 0 ; i < dim_h ; i ++ ) {
1194+ model -> output_norm [i ] += 1.0f ;
1195+ }
1196+ }
1197+ fprintf (stderr , "tq_load_model: applied Gemma3 RMSNorm +1 weight adjustment\n" );
1198+
1199+ /* Set up layer_is_sliding for Gemma3.
1200+ * Pattern: 5 sliding + 1 full, repeated. Layers 0-4=sliding, 5=full, etc.
1201+ * A layer is global when (layer index + 1) is a multiple of 6; all others are sliding. */
1202+ model -> layer_is_sliding = (int * )calloc ((size_t )n_layers , sizeof (int ));
1203+ if (model -> layer_is_sliding ) {
1204+ for (int l = 0 ; l < n_layers ; l ++ ) {
1205+ /* Full/global attention every 6th layer (indices 5, 11, 17, ...) */
1206+ if ((l + 1 ) % 6 == 0 ) {
1207+ model -> layer_is_sliding [l ] = 0 ; /* global */
1208+ } else {
1209+ model -> layer_is_sliding [l ] = 1 ; /* sliding */
1210+ }
1211+ }
1212+ int n_sliding = 0 , n_global = 0 ;
1213+ for (int l = 0 ; l < n_layers ; l ++ ) {
1214+ if (model -> layer_is_sliding [l ]) {
1215+ n_sliding ++ ;
1216+ } else {
1217+ n_global ++ ;
1218+ }
1219+ }
1220+ fprintf (stderr , "tq_load_model: Gemma3 layer types: %d sliding, %d global\n" ,
1221+ n_sliding , n_global );
1222+ }
1223+ }
1224+
11101225 fprintf (stderr , "tq_load_model: loaded %d layers (%d with self_attn), "
11111226 "dim=%d, heads=%d/%d, vocab=%d\n" ,
11121227 model -> config .n_layers , model -> n_attn_layers ,
@@ -1679,6 +1794,13 @@ tq_model_t* tq_load_tqm(const char* path) {
16791794 c -> use_qk_norm = hdr -> use_qk_norm ;
16801795 c -> attn_output_gate = hdr -> attn_output_gate ;
16811796
1797+ /* Multi-architecture fields */
1798+ c -> model_type = hdr -> model_type ;
1799+ c -> sliding_window = hdr -> sliding_window ;
1800+ c -> rope_local_base_freq = hdr -> rope_local_base_freq ;
1801+ c -> n_norms_per_block = hdr -> n_norms_per_block ;
1802+ c -> query_pre_attn_scalar = hdr -> query_pre_attn_scalar ;
1803+
16821804 /* Attn layer indices */
16831805 model -> n_attn_layers = hdr -> n_attn_layers ;
16841806 if (hdr -> n_attn_layers > 0 ) {
@@ -1748,6 +1870,13 @@ tq_model_t* tq_load_tqm(const char* path) {
17481870 TQM_READ_FP32 (layer -> attn_norm , dim );
17491871 TQM_READ_FP32 (layer -> ffn_norm , dim );
17501872
1873+ /* Gemma3 extra norms */
1874+ if (c -> model_type == 1 ) {
1875+ layer -> post_attn_norm = layer -> ffn_norm ; /* shares storage */
1876+ TQM_READ_FP32 (layer -> pre_ffn_norm , dim );
1877+ TQM_READ_FP32 (layer -> post_ffn_norm , dim );
1878+ }
1879+
17511880 if (is_attn_layer && is_attn_layer [l ]) {
17521881 /* Self-attention layer */
17531882 TQM_READ_Q4 (layer -> wq_q4 , layer -> wq_q4s , qg_dim , dim );
@@ -1814,6 +1943,20 @@ tq_model_t* tq_load_tqm(const char* path) {
18141943 model -> use_q4_weights = 1 ;
18151944 free (is_attn_layer );
18161945
1946+ /* Set up Gemma3 layer_is_sliding from TQM */
1947+ if (c -> model_type == 1 && c -> sliding_window > 0 ) {
1948+ model -> layer_is_sliding = (int * )calloc ((size_t )c -> n_layers , sizeof (int ));
1949+ if (model -> layer_is_sliding ) {
1950+ for (int l = 0 ; l < c -> n_layers ; l ++ ) {
1951+ if ((l + 1 ) % 6 == 0 ) {
1952+ model -> layer_is_sliding [l ] = 0 ; /* global */
1953+ } else {
1954+ model -> layer_is_sliding [l ] = 1 ; /* sliding */
1955+ }
1956+ }
1957+ }
1958+ }
1959+
18171960 /* Runtime Q4 quantization of lm_head (output projection) for fast logit computation.
18181961 * BF16 matmul on 248K x 1024 is slow; Q4 matmul is ~4x faster. */
18191962 if (model -> output_weight_bf16 ) {
@@ -1982,6 +2125,12 @@ int tq_save_tqm(tq_model_t* model, const char* tokenizer_path,
19822125 hdr .use_qk_norm = c -> use_qk_norm ;
19832126 hdr .attn_output_gate = c -> attn_output_gate ;
19842127
2128+ hdr .model_type = c -> model_type ;
2129+ hdr .sliding_window = c -> sliding_window ;
2130+ hdr .rope_local_base_freq = c -> rope_local_base_freq ;
2131+ hdr .n_norms_per_block = c -> n_norms_per_block ;
2132+ hdr .query_pre_attn_scalar = c -> query_pre_attn_scalar ;
2133+
19852134 hdr .weight_quant = 4 ; /* Q4 */
19862135 hdr .embed_format = 16 ; /* BF16 */
19872136
@@ -2041,6 +2190,12 @@ int tq_save_tqm(tq_model_t* model, const char* tokenizer_path,
20412190 TQM_WRITE_FP32 (layer -> attn_norm , dim );
20422191 TQM_WRITE_FP32 (layer -> ffn_norm , dim );
20432192
2193+ /* Gemma3 extra norms */
2194+ if (c -> model_type == 1 ) {
2195+ TQM_WRITE_FP32 (layer -> pre_ffn_norm , dim );
2196+ TQM_WRITE_FP32 (layer -> post_ffn_norm , dim );
2197+ }
2198+
20442199 if (is_attn_layer [l ]) {
20452200 TQM_WRITE_Q4 (layer -> wq_q4 , layer -> wq_q4s , qg_dim , dim );
20462201 TQM_WRITE_Q4 (layer -> wk_q4 , layer -> wk_q4s , kv_dim , dim );
@@ -2144,6 +2299,7 @@ void tq_free_model(tq_model_t* model) {
21442299 free (model -> _q8_data );
21452300 free (model -> _q4_data );
21462301 free (model -> attn_layer_indices );
2302+ free (model -> layer_is_sliding );
21472303 free (model -> layers );
21482304 free (model );
21492305}
0 commit comments