@@ -1421,6 +1421,110 @@ def set_decoder(self, spec, module):
14211421 gc .collect ()
14221422
14231423
@register_loader("Gemma2Config")
class Gemma2Loader(ModelLoader):
    """Loader registered for ``Gemma2Config`` models.

    Builds a ``TransformerDecoderModelSpec`` from the model configuration and
    copies the decoder weights, vocabulary and special tokens into it.
    """

    @property
    def architecture_name(self):
        """Name of the model architecture handled by this loader."""
        return "Gemma2ForCausalLM"

    def get_model_spec(self, model):
        """Create the decoder spec for ``model`` and fill it with its weights.

        Args:
          model: Object exposing ``config``, ``model`` (the decoder stack) and
            ``lm_head``.

        Returns:
          The populated ``TransformerDecoderModelSpec``.
        """
        cfg = model.config
        layer_count = cfg.num_hidden_layers

        head_count = cfg.num_attention_heads
        kv_head_count = getattr(cfg, "num_key_value_heads", head_count)
        if kv_head_count == head_count:
            # Same number of key/value heads as query heads: pass None.
            kv_head_count = None

        hidden_activation = getattr(cfg, "hidden_activation", "gelu_pytorch_tanh")
        if hidden_activation == "gelu":
            activation = common_spec.Activation.GELU
        else:
            # Every value other than "gelu" maps to the tanh variant.
            activation = common_spec.Activation.GELUTanh

        decoder_options = {
            "activation": activation,
            "pre_norm": True,
            "ffn_glu": True,
            "rms_norm": True,
            "rotary_dim": 0,
            "rotary_interleave": False,
            "rotary_base": getattr(cfg, "rope_theta", 10000),
            "num_heads_kv": kv_head_count,
            "head_dim": cfg.head_dim,
            "pre_post_layer_norm": True,
        }
        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
            layer_count,
            head_count,
            **decoder_options,
        )

        self.set_decoder(spec.decoder, model.model)
        self.set_linear(spec.decoder.projection, model.lm_head)
        # The embedding multiplier is the square root of the hidden size.
        spec.decoder.embeddings.multiply_by_sqrt_depth = cfg.hidden_size ** 0.5
        return spec

    def get_vocabulary(self, model, tokenizer):
        """Return the token list sized to exactly ``model.config.vocab_size``.

        The list from the parent loader is padded with ``<extra_id_N>``
        placeholders when it is too short and truncated when it is too long.
        """
        tokens = super().get_vocabulary(model, tokenizer)

        vocab_size = model.config.vocab_size
        shortfall = vocab_size - len(tokens)
        if shortfall > 0:
            tokens.extend("<extra_id_%d>" % index for index in range(shortfall))
        elif shortfall < 0:
            tokens = tokens[:vocab_size]

        return tokens

    def set_vocabulary(self, spec, tokens):
        """Register ``tokens`` as the vocabulary of ``spec``."""
        spec.register_vocabulary(tokens)

    def set_config(self, config, model, tokenizer):
        """Copy the special tokens and the RMS norm epsilon into ``config``."""
        for token_field in ("bos_token", "eos_token", "unk_token"):
            setattr(config, token_field, getattr(tokenizer, token_field))
        config.layer_norm_epsilon = model.config.rms_norm_eps

    def set_layer_norm(self, spec, layer_norm):
        """Copy the norm weight into ``spec`` and flag ``layer_norm_use_residual``."""
        spec.gamma = layer_norm.weight
        spec.layer_norm_use_residual = True

    def set_decoder(self, spec, module):
        """Fill the decoder ``spec`` from ``module``.

        Copies the embeddings, the final norm and, for every layer, the four
        layer norms, the attention projections and the feed-forward linears.
        The ``self_attn`` and ``mlp`` attributes of each source layer are
        deleted once they have been copied.
        """
        spec.scale_embeddings = True
        spec.start_from_zero_embedding = False
        self.set_embeddings(spec.embeddings, module.embed_tokens)
        self.set_layer_norm(spec.layer_norm, module.norm)

        # (attribute on the layer spec, attribute on the source layer)
        norm_fields = (
            ("input_layer_norm", "input_layernorm"),
            ("post_attention_layer_norm", "post_attention_layernorm"),
            ("pre_feedforward_layer_norm", "pre_feedforward_layernorm"),
            ("post_feedforward_layer_norm", "post_feedforward_layernorm"),
        )

        for layer_spec, layer in zip(spec.layer, module.layers):
            for spec_field, module_field in norm_fields:
                self.set_layer_norm(
                    getattr(layer_spec, spec_field), getattr(layer, module_field)
                )

            attention = layer.self_attn
            query_weight = attention.q_proj.weight
            key_weight = attention.k_proj.weight
            value_weight = attention.v_proj.weight
            output_weight = attention.o_proj.weight

            # Query, key and value weights are concatenated into one input
            # projection; the output projection is stored separately.
            attention_linears = layer_spec.self_attention.linear
            attention_linears[0].weight = torch.cat(
                [query_weight, key_weight, value_weight]
            )
            attention_linears[1].weight = output_weight

            feed_forward = layer.mlp
            self.set_linear(layer_spec.ffn.linear_0, feed_forward.gate_proj)
            self.set_linear(layer_spec.ffn.linear_0_noact, feed_forward.up_proj)
            self.set_linear(layer_spec.ffn.linear_1, feed_forward.down_proj)

            # Remove the converted submodules from the source layer and run
            # the collector (presumably to limit peak memory during
            # conversion -- confirm).
            del layer.self_attn
            del layer.mlp
            gc.collect()
1526+
1527+
14241528@register_loader ("LlamaConfig" )
14251529class LlamaLoader (ModelLoader ):
14261530 @property
0 commit comments