Skip to content

Commit a9039ad

Browse files
authored
feat: support parsing dtype field from config. (#1313)
1 parent f73e410 commit a9039ad

1 file changed

Lines changed: 22 additions & 6 deletions

File tree

xllm/core/framework/hf_model_loader.cpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,17 @@ namespace xllm {
5555

5656
namespace {
5757

58+
// Builds a reader over a copy of `reader`'s JSON data in which the
// "torch_dtype" key mirrors "dtype" whenever only "dtype" is present.
// The input reader itself is left untouched; the (possibly patched) copy is
// serialized and re-parsed into a fresh JsonReader, which is returned.
JsonReader normalize_config_torch_dtype(const JsonReader& reader) {
  auto patched = reader.data();

  const bool has_torch_dtype = patched.contains("torch_dtype");
  const bool has_dtype = patched.contains("dtype");
  if (has_dtype && !has_torch_dtype) {
    // Only fill in the alias; an existing "torch_dtype" always wins.
    patched["torch_dtype"] = patched["dtype"];
  }

  // NOTE(review): the outcome of parse_text() is not inspected here —
  // confirm whether a failure needs to be surfaced to the caller.
  JsonReader result;
  result.parse_text(patched.dump());
  return result;
}
68+
5869
bool is_compressed_tensors_fp8_scheme(const nlohmann::json& config) {
5970
auto type_it = config.find("type");
6071
auto num_bits_it = config.find("num_bits");
@@ -752,7 +763,8 @@ bool HFModelLoader::load_model_args(const std::string& model_weights_path) {
752763
<< resolved_model_type;
753764
return false;
754765
}
755-
model_args_loader(reader, &args_);
766+
const JsonReader config_reader = normalize_config_torch_dtype(reader);
767+
model_args_loader(config_reader, &args_);
756768

757769
return true;
758770
}
@@ -765,16 +777,20 @@ bool HFModelLoader::load_quant_args(const std::string& model_weights_path) {
765777
return false;
766778
}
767779

768-
if (!load_quant_cfg(reader, quant_args_)) {
780+
const JsonReader config_reader = normalize_config_torch_dtype(reader);
781+
782+
if (!load_quant_cfg(config_reader, quant_args_)) {
769783
return false;
770784
}
771785

772786
// load quantization args for npu if exists
773-
if (reader.contains("quantize")) {
774-
quant_args_.quantize_type() = reader.value_or<std::string>("quantize", "");
787+
if (config_reader.contains("quantize")) {
788+
quant_args_.quantize_type() =
789+
config_reader.value_or<std::string>("quantize", "");
775790
}
776-
if (reader.contains("torch_dtype")) {
777-
quant_args_.torch_dtype() = reader.value_or<std::string>("torch_dtype", "");
791+
if (config_reader.contains("torch_dtype")) {
792+
quant_args_.torch_dtype() =
793+
config_reader.value_or<std::string>("torch_dtype", "");
778794
}
779795

780796
// awq quantization args

0 commit comments

Comments
 (0)