@@ -318,6 +318,54 @@ TEST_F(DenseMLPTest, SmoothquantLoadStateDictTest) {
318318 LOG (INFO) << " State dict loading test passed - output sum: " << output_sum;
319319}
320320
321+ TEST_F (DenseMLPTest, Fp8IgnoredDownProjLoadsAsUnquantized) {
322+ QuantArgs fp8_quant_args;
323+ fp8_quant_args.quant_method () = kQuantMethodFp8 ;
324+ fp8_quant_args.bits () = 8 ;
325+ fp8_quant_args.activation_dynamic () = false ;
326+ fp8_quant_args.ignored_modules () = {" model.layers.1.mlp.down_proj" };
327+
328+ const int64_t hidden_size = 16 ;
329+ const int64_t intermediate_size = 32 ;
330+ auto mlp = DenseMLP (DenseMLPImpl (hidden_size,
331+ intermediate_size,
332+ /* is_gated=*/ true ,
333+ /* has_bias=*/ false ,
334+ /* hidden_act=*/ " silu" ,
335+ /* enable_result_reduction=*/ true ,
336+ fp8_quant_args,
337+ parallel_args_.tp_group_ ,
338+ options_,
339+ " model.layers.1.mlp" ));
340+
341+ std::unordered_map<std::string, torch::Tensor> weight_dict;
342+ auto fp8_weight_options = options_.dtype (torch::kFloat8_e4m3fn );
343+ auto scale_options = options_.dtype (torch::kFloat32 );
344+
345+ weight_dict[" gate_proj.weight" ] =
346+ torch::zeros ({intermediate_size, hidden_size}, fp8_weight_options);
347+ weight_dict[" gate_proj.weight_scale" ] = torch::ones ({1 }, scale_options);
348+ weight_dict[" gate_proj.input_scale" ] = torch::ones ({1 }, scale_options);
349+
350+ weight_dict[" up_proj.weight" ] =
351+ torch::zeros ({intermediate_size, hidden_size}, fp8_weight_options);
352+ weight_dict[" up_proj.weight_scale" ] = torch::ones ({1 }, scale_options);
353+ weight_dict[" up_proj.input_scale" ] = torch::ones ({1 }, scale_options);
354+
355+ weight_dict[" down_proj.weight" ] =
356+ torch::zeros ({hidden_size, intermediate_size}, options_);
357+
358+ StateDict state_dict (weight_dict);
359+ mlp->load_state_dict (state_dict);
360+
361+ const auto params = mlp->named_parameters (/* recurse=*/ true );
362+ EXPECT_TRUE (params.contains (" gate_up_proj.weight_scale" ));
363+ EXPECT_TRUE (params.contains (" gate_up_proj.input_scale" ));
364+ EXPECT_TRUE (params.contains (" down_proj.weight" ));
365+ EXPECT_FALSE (params.contains (" down_proj.weight_scale" ));
366+ EXPECT_FALSE (params.contains (" down_proj.input_scale" ));
367+ }
368+
321369TEST_F (DenseMLPTest, SmoothquantPrecisionVerificationTest) {
322370 // Test precision verification with custom input and expected output
323371 const int64_t batch_size = 16 ;
0 commit comments