Skip to content

Commit cfe4b1a

Browse files
authored
Merge pull request #267 from InfiniTensor/issue/263_T2-1-4
【比赛2025秋】T2-1-4 qwen3vl
2 parents 66bfd28 + b1f6af3 commit cfe4b1a

File tree

14 files changed

+3133
-2
lines changed

14 files changed

+3133
-2
lines changed

include/infinicore_infer.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
#include "infinicore_infer/cache.h"
55
#include "infinicore_infer/weights_loader.h"
66

7+
78
#include "infinicore_infer/models/deepseek.h"
89
#include "infinicore_infer/models/jiuge.h"
10+
#include "infinicore_infer/models/jiuge_awq.h"
11+
#include "infinicore_infer/models/qwen3vl.h"
12+
913

1014
#endif /* INFINICORE_INFER_H */
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
#ifndef QWEN3VL_WEIGHTS_H
2+
#define QWEN3VL_WEIGHTS_H
3+
4+
#include <infiniccl.h>
5+
#include <infiniop.h>
6+
#include <infinirt.h>
7+
8+
#include <stddef.h>
9+
#include <stdint.h>
10+
11+
struct Qwen3vlWeights;
12+
13+
// Function pointer signatures
14+
typedef void (*qwen3vl_load_global_fn)(Qwen3vlWeights *, void *cpu_ptr);
15+
typedef void (*qwen3vl_load_layer_fn)(Qwen3vlWeights *, void *cpu_ptr, size_t layer_id);
16+
// Struct containing all weight loading functions
17+
typedef struct {
18+
// Global
19+
qwen3vl_load_global_fn load_input_embd;
20+
qwen3vl_load_global_fn load_output_norm;
21+
qwen3vl_load_global_fn load_output_embd;
22+
23+
// Attention
24+
qwen3vl_load_layer_fn load_attn_norm;
25+
qwen3vl_load_layer_fn load_attn_q_norm;
26+
qwen3vl_load_layer_fn load_attn_k_norm;
27+
qwen3vl_load_layer_fn load_attn_qkv_proj;
28+
qwen3vl_load_layer_fn load_attn_o_proj;
29+
30+
// MLP
31+
qwen3vl_load_layer_fn load_mlp_norm;
32+
qwen3vl_load_layer_fn load_mlp_gate_up;
33+
qwen3vl_load_layer_fn load_mlp_down;
34+
35+
} Qwen3vlLangWeightLoader;
36+
37+
typedef struct {
38+
// Patch_embed
39+
qwen3vl_load_global_fn load_patch_embed_weight;
40+
qwen3vl_load_global_fn load_patch_embed_bias;
41+
qwen3vl_load_global_fn load_pos_embed_weight;
42+
43+
// blocks attn
44+
qwen3vl_load_layer_fn load_attn_proj_weight;
45+
qwen3vl_load_layer_fn load_attn_proj_bias;
46+
qwen3vl_load_layer_fn load_attn_qkv_weight;
47+
qwen3vl_load_layer_fn load_attn_qkv_bias;
48+
49+
// block mlp
50+
qwen3vl_load_layer_fn load_mlp_linear_fc1_weight;
51+
qwen3vl_load_layer_fn load_mlp_linear_fc1_bias;
52+
qwen3vl_load_layer_fn load_mlp_linear_fc2_weight;
53+
qwen3vl_load_layer_fn load_mlp_linear_fc2_bias;
54+
55+
// block norm
56+
qwen3vl_load_layer_fn load_norm1_weight;
57+
qwen3vl_load_layer_fn load_norm1_bias;
58+
qwen3vl_load_layer_fn load_norm2_weight;
59+
qwen3vl_load_layer_fn load_norm2_bias;
60+
61+
// deepstack_merger
62+
qwen3vl_load_layer_fn load_deepstack_merger_linear_fc1_weight;
63+
qwen3vl_load_layer_fn load_deepstack_merger_linear_fc1_bias;
64+
qwen3vl_load_layer_fn load_deepstack_merger_linear_fc2_weight;
65+
qwen3vl_load_layer_fn load_deepstack_merger_linear_fc2_bias;
66+
qwen3vl_load_layer_fn load_deepstack_merger_norm_weight;
67+
qwen3vl_load_layer_fn load_deepstack_merger_norm_bias;
68+
69+
// merger
70+
qwen3vl_load_global_fn load_merger_linear_fc1_weight;
71+
qwen3vl_load_global_fn load_merger_linear_fc1_bias;
72+
qwen3vl_load_global_fn load_merger_linear_fc2_weight;
73+
qwen3vl_load_global_fn load_merger_linear_fc2_bias;
74+
qwen3vl_load_global_fn load_merger_norm_weight;
75+
qwen3vl_load_global_fn load_merger_norm_bias;
76+
77+
} Qwen3vlVisWeightLoader;
78+
79+
typedef struct {
80+
Qwen3vlLangWeightLoader lang_loader;
81+
Qwen3vlVisWeightLoader vis_loader;
82+
} Qwen3vlWeightLoader;
83+
84+
struct Qwen3vlModel;
85+
86+
typedef struct {
87+
size_t bos_token_id;
88+
size_t eos_token_id;
89+
size_t head_dim;
90+
size_t hidden_size;
91+
float initializer_range;
92+
size_t intermediate_size;
93+
size_t max_tokens;
94+
size_t num_attention_heads;
95+
size_t num_hidden_layers;
96+
size_t num_key_value_heads;
97+
float rms_norm_eps;
98+
size_t mrope_section[3];
99+
size_t rope_theta;
100+
size_t vocab_size;
101+
} Qwen3vlTextMeta;
102+
103+
typedef struct {
104+
size_t depth;
105+
size_t deepstack_visual_indexes[3];
106+
size_t hidden_size;
107+
size_t in_channels;
108+
float initializer_range;
109+
size_t intermediate_size;
110+
size_t num_heads;
111+
size_t num_position_embeddings;
112+
size_t out_hidden_size;
113+
size_t patch_size;
114+
size_t spatial_merge_size;
115+
size_t temporal_patch_size;
116+
} Qwen3vlVisMeta;
117+
118+
typedef struct {
119+
infiniDtype_t dtype; // INFINI_DTYPE_BF16
120+
121+
Qwen3vlTextMeta text_meta;
122+
Qwen3vlVisMeta vis_meta;
123+
124+
size_t image_token_id;
125+
size_t video_token_id;
126+
size_t vision_end_token_id;
127+
size_t vision_start_token_id;
128+
} Qwen3vlMeta;
129+
130+
//////////////////// APIs ///////////////////////
131+
/// @brief 创建模型
132+
/// @param device 协处理器种类
133+
/// @param ndev 协处理器数量
134+
/// @param dev_ids 协处理器编号,长度为 ndev
135+
__INFINI_C __export struct Qwen3vlModel *
136+
createQwen3vlModel(const Qwen3vlMeta *,
137+
const Qwen3vlWeights *);
138+
139+
__INFINI_C Qwen3vlWeights *
140+
createQwen3vlWeights(const Qwen3vlMeta *meta,
141+
infiniDevice_t device,
142+
int ndev,
143+
const int *dev_ids,
144+
bool transpose_weight);
145+
146+
__INFINI_C __export Qwen3vlWeightLoader *
147+
createQwen3vlWeightLoader();
148+
149+
/// @brief 销毁模型
150+
__INFINI_C __export void destroyQwen3vlModel(struct Qwen3vlModel *);
151+
152+
__INFINI_C __export struct Qwen3vlCache *
153+
createQwen3vlCache(const struct Qwen3vlModel *);
154+
155+
__INFINI_C __export void
156+
dropQwen3vlCache(const struct Qwen3vlModel *,
157+
struct Qwen3vlCache *);
158+
159+
/// @brief 批次推理一轮,并采样出新的 token
160+
/// @param tokens 输入 token 地址
161+
/// @param ntok 输入 token 数量
162+
/// @param nreq 请求数量
163+
/// @param req_lens 每个请求的 token 数量
164+
/// @param req_pos 每个请求的起始位置
165+
/// @param kv_caches 每个请求的 KV Cache
166+
/// @param temperature 采样温度(0. 表示贪心采样)
167+
/// @param topk 采样 topk(1 表示贪心采样)
168+
/// @param topp 采样 topp
169+
/// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq
170+
__INFINI_C __export void
171+
inferBatchQwen3vl(struct Qwen3vlModel *,
172+
const uint32_t *tokens, uint32_t ntok,
173+
void *pixel_values, uint32_t total_patches,
174+
uint32_t *image_grid_thw, uint32_t num_images,
175+
void *pixel_values_videos, uint32_t total_patches_videos,
176+
uint32_t *video_grid_thw, uint32_t num_videos,
177+
uint32_t patch_features,
178+
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
179+
struct Qwen3vlCache **caches,
180+
const float *temperature, const uint32_t *topk, const float *topp,
181+
uint32_t *output);
182+
183+
/// @brief 批次推理一轮,输出 output embedding 后的 logits
184+
/// @param tokens 输入 token 地址
185+
/// @param ntok 输入 token 数量
186+
/// @param nreq 请求数量
187+
/// @param req_lens 每个请求的 token 数量
188+
/// @param req_pos 每个请求的起始位置
189+
/// @param kv_caches 每个请求的 KV Cache
190+
/// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq
191+
__INFINI_C __export void
192+
forwardBatchQwen3vl(struct Qwen3vlModel *,
193+
const uint32_t *tokens, uint32_t ntok,
194+
void *pixel_values, uint32_t total_patches,
195+
uint32_t *image_grid_thw, uint32_t num_images,
196+
void *pixel_values_videos, uint32_t total_patches_videos,
197+
uint32_t *video_grid_thw, uint32_t num_videos,
198+
uint32_t patch_features,
199+
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
200+
struct Qwen3vlCache **caches,
201+
void *logits);
202+
203+
#endif // QWEN3VL_WEIGHTS_H

scripts/libinfinicore_infer/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,17 @@
88
DeepSeekV3WeightLoaderCStruct,
99
DeepSeekV3CacheCStruct,
1010
)
11+
from .qwen3vl import (
12+
Qwen3vlModel,
13+
Qwen3vlMetaCStruct,
14+
TextMetaCStruct,
15+
VisMetaCStruct,
16+
Qwen3vlWeightsCStruct,
17+
Qwen3vlWeightLoaderCStruct,
18+
Qwen3vlVisWeightLoaderCStruct,
19+
Qwen3vlLangWeightLoaderCStruct,
20+
Qwen3vlCacheCStruct,
21+
)
1122

1223
__all__ = [
1324
"DataType",
@@ -23,5 +34,15 @@
2334
"DeepSeekV3MetaCStruct",
2435
"DeepSeekV3WeightsCStruct",
2536
"DeepSeekV3WeightLoaderCStruct",
37+
"DeepSeekV3CacheCStruct",
38+
"Qwen3vlModel",
39+
"Qwen3vlMetaCStruct",
40+
"TextMetaCStruct",
41+
"VisMetaCStruct",
42+
"Qwen3vlWeightsCStruct",
43+
"Qwen3vlWeightLoaderCStruct",
44+
"Qwen3vlVisWeightLoaderCStruct",
45+
"Qwen3vlLangWeightLoaderCStruct",
46+
"Qwen3vlCacheCStruct",
2647
"ModelRegister",
2748
]

0 commit comments

Comments
 (0)