Skip to content

Commit 0f8270a

Browse files
authored
Merge pull request #306 from InfiniTensor/issue/305
issue/305 - feat: Add support for mistral model type
2 parents 4eab14d + 99396da commit 0f8270a

File tree

4 files changed

+74
-8
lines changed

4 files changed

+74
-8
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
#include "mistral_for_causal_lm.hpp"
#include "../models_registry.hpp"

// IWYU: this TU uses these names directly; do not rely on transitive includes.
#include <memory>
#include <stdexcept>
#include <string>

namespace infinilm::models::mistral {

/// @brief Validates and normalizes a parsed model config for the mistral type.
///
/// Checks that the config's "model_type" field is exactly "mistral", then
/// fills in defaults that mistral checkpoints commonly omit:
///   - "head_dim"       = hidden_size / num_attention_heads
///   - "attention_bias" = false
///
/// @param model_config parsed model configuration; its backing JSON is
///        mutated in place when defaults are missing.
/// @return the same @p model_config pointer, with defaults filled in.
/// @throws std::runtime_error if "model_type" is not "mistral".
std::shared_ptr<infinilm::config::ModelConfig> create_mistral_model_config(std::shared_ptr<infinilm::config::ModelConfig> model_config) {
    const std::string &model_type = model_config->get<std::string>("model_type");
    if ("mistral" != model_type) {
        throw std::runtime_error(
            "infinilm::models::mistral::create_mistral_model_config: model_type is not mistral");
    }

    nlohmann::json &config_json = model_config->get_config_json();

    // Mistral configs frequently omit head_dim; derive it from the standard
    // hidden_size / num_attention_heads relationship (integer division).
    if (!config_json.contains("head_dim")) {
        size_t head_dim = model_config->get<size_t>("hidden_size")
                        / model_config->get<size_t>("num_attention_heads");
        config_json["head_dim"] = head_dim;
    }

    // Mistral attention projections carry no bias terms unless the config
    // explicitly says otherwise.
    if (!config_json.contains("attention_bias")) {
        config_json["attention_bias"] = false;
    }

    return model_config;
}

} // namespace infinilm::models::mistral

namespace {

// File-local registration: makes the "mistral" model type constructible by
// name through the global causal-LM factory.
INFINILM_REGISTER_CAUSAL_LM_MODEL(
    mistral,
    infinilm::models::mistral::MistralForCausalLM,
    infinilm::models::mistral::create_mistral_model_config);

} // namespace
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
#pragma once

#include "../../layers/common_modules.hpp"

#include <memory>

namespace infinilm::models::mistral {

// Mistral is structurally identical to the shared text-decoder templates,
// so its layer stack is expressed as aliases over the generic building blocks.
using MistralMLP = infinilm::layers::MLP;
using MistralAttention = infinilm::layers::attention::Attention;
using MistralDecoderLayer = infinilm::layers::causal_lm_templates::TextDecoderLayer<MistralAttention, MistralMLP>;
using MistralModel = infinilm::layers::causal_lm_templates::TextModel<MistralDecoderLayer>;
using MistralForCausalLM = infinilm::layers::causal_lm_templates::TextCausalLM<MistralModel>;

// Validates a parsed config (its model_type must be "mistral") and fills in
// defaults such as head_dim and attention_bias; returns the same pointer.
std::shared_ptr<infinilm::config::ModelConfig> create_mistral_model_config(std::shared_ptr<infinilm::config::ModelConfig> model_config);

} // namespace infinilm::models::mistral

examples/jiuge.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -221,14 +221,17 @@ def test(
221221
# prompt = "山东最高的山是?"
222222
if isinstance(prompts, str):
223223
prompts = [prompts]
224-
input_contents = [
225-
tokenizer.apply_chat_template(
226-
conversation=[{"role": "user", "content": prompt}],
227-
add_generation_prompt=True,
228-
tokenize=False,
229-
)
230-
for prompt in prompts
231-
]
224+
if hasattr(tokenizer, 'chat_template') and tokenizer.chat_template is not None:
225+
input_contents = [
226+
tokenizer.apply_chat_template(
227+
conversation=[{"role": "user", "content": prompt}],
228+
add_generation_prompt=True,
229+
tokenize=False,
230+
)
231+
for prompt in prompts
232+
]
233+
else:
234+
input_contents = prompts
232235

233236
# input_ids_list = tokenizer.batch_encode_plus(input_contents)[
234237
# "input_ids"

python/infinilm/auto_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,7 @@ def from_pretrained(model_path):
3333
return LlamaConfig(**config_dict)
3434
elif config_dict["model_type"] in ["qwen3_next" , "minicpm_sala" , "qwen3_vl" , "qwen3_moe"]:
3535
return LlamaConfig(**config_dict)
36+
elif config_dict["model_type"] == "mistral":
37+
return LlamaConfig(**config_dict)
3638

3739
raise ValueError(f"Unsupported model type `{config_dict['model_type']}`.")

0 commit comments

Comments
 (0)