|
| 1 | +#ifndef QWEN3VL_WEIGHTS_H |
| 2 | +#define QWEN3VL_WEIGHTS_H |
| 3 | + |
| 4 | +#include <infiniccl.h> |
| 5 | +#include <infiniop.h> |
| 6 | +#include <infinirt.h> |
| 7 | + |
| 8 | +#include <stddef.h> |
| 9 | +#include <stdint.h> |
| 10 | + |
| 11 | +struct Qwen3vlWeights; |
| 12 | + |
| 13 | +// Function pointer signatures |
| 14 | +typedef void (*qwen3vl_load_global_fn)(Qwen3vlWeights *, void *cpu_ptr); |
| 15 | +typedef void (*qwen3vl_load_layer_fn)(Qwen3vlWeights *, void *cpu_ptr, size_t layer_id); |
| 16 | +// Struct containing all weight loading functions |
| 17 | +typedef struct { |
| 18 | + // Global |
| 19 | + qwen3vl_load_global_fn load_input_embd; |
| 20 | + qwen3vl_load_global_fn load_output_norm; |
| 21 | + qwen3vl_load_global_fn load_output_embd; |
| 22 | + |
| 23 | + // Attention |
| 24 | + qwen3vl_load_layer_fn load_attn_norm; |
| 25 | + qwen3vl_load_layer_fn load_attn_q_norm; |
| 26 | + qwen3vl_load_layer_fn load_attn_k_norm; |
| 27 | + qwen3vl_load_layer_fn load_attn_qkv_proj; |
| 28 | + qwen3vl_load_layer_fn load_attn_o_proj; |
| 29 | + |
| 30 | + // MLP |
| 31 | + qwen3vl_load_layer_fn load_mlp_norm; |
| 32 | + qwen3vl_load_layer_fn load_mlp_gate_up; |
| 33 | + qwen3vl_load_layer_fn load_mlp_down; |
| 34 | + |
| 35 | +} Qwen3vlLangWeightLoader; |
| 36 | + |
| 37 | +typedef struct { |
| 38 | + // Patch_embed |
| 39 | + qwen3vl_load_global_fn load_patch_embed_weight; |
| 40 | + qwen3vl_load_global_fn load_patch_embed_bias; |
| 41 | + qwen3vl_load_global_fn load_pos_embed_weight; |
| 42 | + |
| 43 | + // blocks attn |
| 44 | + qwen3vl_load_layer_fn load_attn_proj_weight; |
| 45 | + qwen3vl_load_layer_fn load_attn_proj_bias; |
| 46 | + qwen3vl_load_layer_fn load_attn_qkv_weight; |
| 47 | + qwen3vl_load_layer_fn load_attn_qkv_bias; |
| 48 | + |
| 49 | + // block mlp |
| 50 | + qwen3vl_load_layer_fn load_mlp_linear_fc1_weight; |
| 51 | + qwen3vl_load_layer_fn load_mlp_linear_fc1_bias; |
| 52 | + qwen3vl_load_layer_fn load_mlp_linear_fc2_weight; |
| 53 | + qwen3vl_load_layer_fn load_mlp_linear_fc2_bias; |
| 54 | + |
| 55 | + // block norm |
| 56 | + qwen3vl_load_layer_fn load_norm1_weight; |
| 57 | + qwen3vl_load_layer_fn load_norm1_bias; |
| 58 | + qwen3vl_load_layer_fn load_norm2_weight; |
| 59 | + qwen3vl_load_layer_fn load_norm2_bias; |
| 60 | + |
| 61 | + // deepstack_merger |
| 62 | + qwen3vl_load_layer_fn load_deepstack_merger_linear_fc1_weight; |
| 63 | + qwen3vl_load_layer_fn load_deepstack_merger_linear_fc1_bias; |
| 64 | + qwen3vl_load_layer_fn load_deepstack_merger_linear_fc2_weight; |
| 65 | + qwen3vl_load_layer_fn load_deepstack_merger_linear_fc2_bias; |
| 66 | + qwen3vl_load_layer_fn load_deepstack_merger_norm_weight; |
| 67 | + qwen3vl_load_layer_fn load_deepstack_merger_norm_bias; |
| 68 | + |
| 69 | + // merger |
| 70 | + qwen3vl_load_global_fn load_merger_linear_fc1_weight; |
| 71 | + qwen3vl_load_global_fn load_merger_linear_fc1_bias; |
| 72 | + qwen3vl_load_global_fn load_merger_linear_fc2_weight; |
| 73 | + qwen3vl_load_global_fn load_merger_linear_fc2_bias; |
| 74 | + qwen3vl_load_global_fn load_merger_norm_weight; |
| 75 | + qwen3vl_load_global_fn load_merger_norm_bias; |
| 76 | + |
| 77 | +} Qwen3vlVisWeightLoader; |
| 78 | + |
| 79 | +typedef struct { |
| 80 | + Qwen3vlLangWeightLoader lang_loader; |
| 81 | + Qwen3vlVisWeightLoader vis_loader; |
| 82 | +} Qwen3vlWeightLoader; |
| 83 | + |
| 84 | +struct Qwen3vlModel; |
| 85 | + |
| 86 | +typedef struct { |
| 87 | + size_t bos_token_id; |
| 88 | + size_t eos_token_id; |
| 89 | + size_t head_dim; |
| 90 | + size_t hidden_size; |
| 91 | + float initializer_range; |
| 92 | + size_t intermediate_size; |
| 93 | + size_t max_tokens; |
| 94 | + size_t num_attention_heads; |
| 95 | + size_t num_hidden_layers; |
| 96 | + size_t num_key_value_heads; |
| 97 | + float rms_norm_eps; |
| 98 | + size_t mrope_section[3]; |
| 99 | + size_t rope_theta; |
| 100 | + size_t vocab_size; |
| 101 | +} Qwen3vlTextMeta; |
| 102 | + |
| 103 | +typedef struct { |
| 104 | + size_t depth; |
| 105 | + size_t deepstack_visual_indexes[3]; |
| 106 | + size_t hidden_size; |
| 107 | + size_t in_channels; |
| 108 | + float initializer_range; |
| 109 | + size_t intermediate_size; |
| 110 | + size_t num_heads; |
| 111 | + size_t num_position_embeddings; |
| 112 | + size_t out_hidden_size; |
| 113 | + size_t patch_size; |
| 114 | + size_t spatial_merge_size; |
| 115 | + size_t temporal_patch_size; |
| 116 | +} Qwen3vlVisMeta; |
| 117 | + |
| 118 | +typedef struct { |
| 119 | + infiniDtype_t dtype; // INFINI_DTYPE_BF16 |
| 120 | + |
| 121 | + Qwen3vlTextMeta text_meta; |
| 122 | + Qwen3vlVisMeta vis_meta; |
| 123 | + |
| 124 | + size_t image_token_id; |
| 125 | + size_t video_token_id; |
| 126 | + size_t vision_end_token_id; |
| 127 | + size_t vision_start_token_id; |
| 128 | +} Qwen3vlMeta; |
| 129 | + |
| 130 | +//////////////////// APIs /////////////////////// |
| 131 | +/// @brief 创建模型 |
| 132 | +/// @param device 协处理器种类 |
| 133 | +/// @param ndev 协处理器数量 |
| 134 | +/// @param dev_ids 协处理器编号,长度为 ndev |
| 135 | +__INFINI_C __export struct Qwen3vlModel * |
| 136 | +createQwen3vlModel(const Qwen3vlMeta *, |
| 137 | + const Qwen3vlWeights *); |
| 138 | + |
| 139 | +__INFINI_C Qwen3vlWeights * |
| 140 | +createQwen3vlWeights(const Qwen3vlMeta *meta, |
| 141 | + infiniDevice_t device, |
| 142 | + int ndev, |
| 143 | + const int *dev_ids, |
| 144 | + bool transpose_weight); |
| 145 | + |
| 146 | +__INFINI_C __export Qwen3vlWeightLoader * |
| 147 | +createQwen3vlWeightLoader(); |
| 148 | + |
| 149 | +/// @brief 销毁模型 |
| 150 | +__INFINI_C __export void destroyQwen3vlModel(struct Qwen3vlModel *); |
| 151 | + |
| 152 | +__INFINI_C __export struct Qwen3vlCache * |
| 153 | +createQwen3vlCache(const struct Qwen3vlModel *); |
| 154 | + |
| 155 | +__INFINI_C __export void |
| 156 | +dropQwen3vlCache(const struct Qwen3vlModel *, |
| 157 | + struct Qwen3vlCache *); |
| 158 | + |
| 159 | +/// @brief 批次推理一轮,并采样出新的 token |
| 160 | +/// @param tokens 输入 token 地址 |
| 161 | +/// @param ntok 输入 token 数量 |
| 162 | +/// @param nreq 请求数量 |
| 163 | +/// @param req_lens 每个请求的 token 数量 |
| 164 | +/// @param req_pos 每个请求的起始位置 |
| 165 | +/// @param kv_caches 每个请求的 KV Cache |
| 166 | +/// @param temperature 采样温度(0. 表示贪心采样) |
| 167 | +/// @param topk 采样 topk(1 表示贪心采样) |
| 168 | +/// @param topp 采样 topp |
| 169 | +/// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq |
| 170 | +__INFINI_C __export void |
| 171 | +inferBatchQwen3vl(struct Qwen3vlModel *, |
| 172 | + const uint32_t *tokens, uint32_t ntok, |
| 173 | + void *pixel_values, uint32_t total_patches, |
| 174 | + uint32_t *image_grid_thw, uint32_t num_images, |
| 175 | + void *pixel_values_videos, uint32_t total_patches_videos, |
| 176 | + uint32_t *video_grid_thw, uint32_t num_videos, |
| 177 | + uint32_t patch_features, |
| 178 | + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, |
| 179 | + struct Qwen3vlCache **caches, |
| 180 | + const float *temperature, const uint32_t *topk, const float *topp, |
| 181 | + uint32_t *output); |
| 182 | + |
| 183 | +/// @brief 批次推理一轮,输出 output embedding 后的 logits |
| 184 | +/// @param tokens 输入 token 地址 |
| 185 | +/// @param ntok 输入 token 数量 |
| 186 | +/// @param nreq 请求数量 |
| 187 | +/// @param req_lens 每个请求的 token 数量 |
| 188 | +/// @param req_pos 每个请求的起始位置 |
| 189 | +/// @param kv_caches 每个请求的 KV Cache |
| 190 | +/// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq |
| 191 | +__INFINI_C __export void |
| 192 | +forwardBatchQwen3vl(struct Qwen3vlModel *, |
| 193 | + const uint32_t *tokens, uint32_t ntok, |
| 194 | + void *pixel_values, uint32_t total_patches, |
| 195 | + uint32_t *image_grid_thw, uint32_t num_images, |
| 196 | + void *pixel_values_videos, uint32_t total_patches_videos, |
| 197 | + uint32_t *video_grid_thw, uint32_t num_videos, |
| 198 | + uint32_t patch_features, |
| 199 | + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, |
| 200 | + struct Qwen3vlCache **caches, |
| 201 | + void *logits); |
| 202 | + |
| 203 | +#endif // QWEN3VL_WEIGHTS_H |
0 commit comments