@@ -55,6 +55,17 @@ namespace xllm {
5555
5656namespace {
5757
// Builds a new JsonReader from `reader`'s data in which the torch_dtype key
// is filled in from the dtype key when torch_dtype is absent and dtype is
// present. If torch_dtype already exists, or neither key exists, the data is
// passed through unchanged.
//
// `reader` itself is not modified; the change is applied to the local
// `config` value and the result is returned in a separate JsonReader.
//
// NOTE(review): the key literals below appear with a leading space
// (" torch_dtype", " dtype"); this looks like an extraction artifact of the
// diff view rather than the intended key names -- confirm against the file.
58+ JsonReader normalize_config_torch_dtype (const JsonReader& reader) {
// Local working value taken from the reader's data.
59+ auto config = reader.data ();
// Only alias dtype -> torch_dtype when torch_dtype is missing, so an
// explicitly provided torch_dtype is never overwritten.
60+ if (!config.contains (" torch_dtype" ) && config.contains (" dtype" )) {
61+ config[" torch_dtype" ] = config[" dtype" ];
62+ }
63+
// Serialize the (possibly updated) config and parse it into a fresh reader.
// NOTE(review): whatever parse_text returns is not checked here -- confirm
// whether a parse failure needs to be surfaced to the caller.
64+ JsonReader normalized_reader;
65+ normalized_reader.parse_text (config.dump ());
66+ return normalized_reader;
67+ }
68+
5869bool is_compressed_tensors_fp8_scheme (const nlohmann::json& config) {
5970 auto type_it = config.find (" type" );
6071 auto num_bits_it = config.find (" num_bits" );
@@ -752,7 +763,8 @@ bool HFModelLoader::load_model_args(const std::string& model_weights_path) {
752763 << resolved_model_type;
753764 return false ;
754765 }
755- model_args_loader (reader, &args_);
766+ const JsonReader config_reader = normalize_config_torch_dtype (reader);
767+ model_args_loader (config_reader, &args_);
756768
757769 return true ;
758770}
@@ -765,16 +777,20 @@ bool HFModelLoader::load_quant_args(const std::string& model_weights_path) {
765777 return false ;
766778 }
767779
768- if (!load_quant_cfg (reader, quant_args_)) {
780+ const JsonReader config_reader = normalize_config_torch_dtype (reader);
781+
782+ if (!load_quant_cfg (config_reader, quant_args_)) {
769783 return false ;
770784 }
771785
772786 // load quantization args for npu if exists
773- if (reader.contains (" quantize" )) {
774- quant_args_.quantize_type () = reader.value_or <std::string>(" quantize" , " " );
787+ if (config_reader.contains (" quantize" )) {
788+ quant_args_.quantize_type () =
789+ config_reader.value_or <std::string>(" quantize" , " " );
775790 }
776- if (reader.contains (" torch_dtype" )) {
777- quant_args_.torch_dtype () = reader.value_or <std::string>(" torch_dtype" , " " );
791+ if (config_reader.contains (" torch_dtype" )) {
792+ quant_args_.torch_dtype () =
793+ config_reader.value_or <std::string>(" torch_dtype" , " " );
778794 }
779795
780796 // awq quantization args
0 commit comments