|
| 1 | +#ifndef MODEL_DEEPSEEK_OCR_H |
| 2 | +#define MODEL_DEEPSEEK_OCR_H |
| 3 | + |
| 4 | +#include <infiniccl.h> |
| 5 | +#include <infiniop.h> |
| 6 | +#include <infinirt.h> |
| 7 | + |
| 8 | +#include <stdint.h> |
| 9 | + |
// Opaque handle to a constructed DeepSeek-OCR model instance; created by
// createDeepSeekOCRModel() and released by destroyDeepSeekOCRModel().
struct DeepSeekOCRModel;

// Forward declaration of the per-request KV-cache type used by the batch
// inference APIs below (defined elsewhere in the project). Without this,
// each prototype's `struct KVCache **` parameter would declare a new tag
// scoped to that parameter list only — a compiler warning and a source of
// incompatible-type errors across translation units.
struct KVCache;
| 11 | + |
// Hyperparameters describing the DeepSeek-OCR language model.
// Values in the comments are the defaults from the reference checkpoint.
typedef struct
{
    infiniDtype_t dt_logits; // data type of the logits tensor
    infiniDtype_t dt_norm;   // data type of the normalization weights
    // Layer counts
    size_t n_dense_layer;  // layer 0 is a dense MLP layer
    size_t n_sparse_layer; // layers 1-11 are MoE layers
    // Model dimensions
    size_t d;    // hidden_size: 1280
    size_t nh;   // num_attention_heads: 1280 — NOTE(review): equals hidden_size; confirm against the model config
    size_t nkvh; // num_key_value_heads: 1280 — NOTE(review): same suspicious value as nh; verify
    size_t dh;   // head_dim: d/nh = 1 — NOTE(review): a head_dim of 1 is implausible; looks copy-pasted, confirm
    // Dense MLP dimensions
    size_t di_dense; // intermediate_size for dense layer: 6848
    // MoE dimensions
    size_t di_moe;      // moe_intermediate_size: 896
    size_t di_shared;   // shared_expert_intermediate_size: 1792
    size_t nexperts;    // n_routed_experts: 64
    size_t kexperts;    // num_experts_per_tok: 6
    float routed_scale; // routed_scaling_factor: 1.0
    // Context and vocab
    size_t dctx; // max_position_embeddings
    size_t dvoc; // vocab_size: 129280
    // Normalization
    float epsilon;      // rms_norm_eps: 1e-6
    float theta;        // rope_theta: 10000.0
    uint32_t end_token; // eos_token_id
} DeepSeekOCRMeta;
| 40 | + |
// Pointer table for all model weights, supplied by the caller at model
// creation. All entries are read-only borrowed pointers; this header does
// not show the ownership or placement (host vs. device memory) contract —
// NOTE(review): confirm against the implementation of createDeepSeekOCRModel.
typedef struct
{
    size_t n_dense_layer;         // number of dense decoder layers
    size_t n_sparse_layer;        // number of MoE decoder layers
    infiniDtype_t dt_norm, dt_mat; // data types of norm weights and matrices
    // 0 if linear weights are passed as W, any other value if passed as W^T
    int transpose_linear_weights;

    // Embeddings
    const void *input_embd;  // [dvoc, d]
    const void *output_norm; // [d]
    const void *output_embd; // [dvoc, d]

    // Attention layers (all layers: n_dense_layer + n_sparse_layer)
    const void *const *attn_norm; // nlayer * [d]
    const void *const *attn_q;    // nlayer * [d, d] or sharded
    const void *const *attn_k;    // nlayer * [d, d] or sharded
    const void *const *attn_v;    // nlayer * [d, d] or sharded
    const void *const *attn_o;    // nlayer * [d, d] or sharded

    // FFN layers
    const void *const *ffn_norm; // nlayer * [d]

    // Dense MLP (layer 0)
    const void *dense_gate; // [di_dense, d]
    const void *dense_up;   // [di_dense, d]
    const void *dense_down; // [d, di_dense]

    // MoE layers (layers 1-11)
    const void *const *moe_gate_weight; // n_sparse_layer * [nexperts, d]
    const void *const *moe_gate_bias;   // n_sparse_layer * [nexperts]

    // Shared experts
    const void *const *moe_shared_gate; // n_sparse_layer * [di_shared, d]
    const void *const *moe_shared_up;   // n_sparse_layer * [di_shared, d]
    const void *const *moe_shared_down; // n_sparse_layer * [d, di_shared]

    // Routed experts
    const void *const *const *moe_experts_gate; // n_sparse_layer * nexperts * [di_moe, d]
    const void *const *const *moe_experts_up;   // n_sparse_layer * nexperts * [di_moe, d]
    const void *const *const *moe_experts_down; // n_sparse_layer * nexperts * [d, di_moe]

    // Vision Encoder weights
    // SAM ViT-B
    const void *sam_patch_embed;
    const void *sam_patch_embed_bias;
    const void *const *sam_block_norm1;     // 12 layers
    const void *const *sam_block_attn_qkv;  // 12 layers
    const void *const *sam_block_attn_proj; // 12 layers
    const void *const *sam_block_norm2;     // 12 layers
    const void *const *sam_block_mlp_fc1;   // 12 layers
    const void *const *sam_block_mlp_fc2;   // 12 layers
    const void *sam_neck_conv1;
    const void *sam_neck_ln1;
    const void *sam_neck_conv2;
    const void *sam_neck_ln2;

    // CLIP-L
    const void *clip_patch_embed;
    const void *clip_patch_embed_bias;
    const void *clip_position_embed;
    const void *clip_pre_layernorm;
    const void *const *clip_block_ln1;       // 24 layers
    const void *const *clip_block_attn_qkv;  // 24 layers
    const void *const *clip_block_attn_proj; // 24 layers
    const void *const *clip_block_ln2;       // 24 layers
    const void *const *clip_block_mlp_fc1;   // 24 layers
    const void *const *clip_block_mlp_fc2;   // 24 layers

    // Projector
    const void *projector;      // [2048, 1280] Linear projection
    const void *image_newline;  // [1280] Image row separator
    const void *view_seperator; // [1280] View separator (identifier keeps the upstream "seperator" spelling; renaming would break callers)
} DeepSeekOCRWeights;
| 115 | + |
| 116 | +//////////////////// APIs /////////////////////// |
| 117 | + |
/// @brief Create a DeepSeek-OCR model.
/// The first two (unnamed) parameters are the model hyperparameters and the
/// weight pointer table; both are read during construction.
/// @param device  Accelerator (co-processor) type.
/// @param ndev    Number of accelerators.
/// @param dev_ids Accelerator ids; array of length ndev.
__C __export struct DeepSeekOCRModel *
createDeepSeekOCRModel(const DeepSeekOCRMeta *,
                       const DeepSeekOCRWeights *,
                       infiniDevice_t device,
                       int ndev,
                       const int *dev_ids);
| 128 | + |
/// @brief Destroy the model and release its resources.
__C __export void
destroyDeepSeekOCRModel(struct DeepSeekOCRModel *);
| 132 | + |
/// @brief Run one round of batched inference and sample new tokens.
/// @param tokens      Input token buffer.
/// @param ntok        Total number of input tokens.
/// @param nreq        Number of requests.
/// @param req_lens    Token count of each request.
/// @param req_pos     Starting position of each request.
/// @param kv_caches   KV cache of each request.
/// @param temperature Sampling temperature (0. means greedy sampling).
/// @param topk        Sampling top-k (1 means greedy sampling).
/// @param topp        Sampling top-p.
/// @param output      Output token array, one token per request; length at least nreq.
__C __export void
inferBatchDeepSeekOCR(struct DeepSeekOCRModel *,
                      const uint32_t *tokens, uint32_t ntok,
                      const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
                      struct KVCache **kv_caches,
                      const float *temperature, const uint32_t *topk, const float *topp,
                      uint32_t *output);
| 151 | + |
/// @brief Run one round of batched inference and output the logits produced
///        after the output embedding (no sampling).
/// @param tokens    Input token buffer.
/// @param ntok      Total number of input tokens.
/// @param nreq      Number of requests.
/// @param req_lens  Token count of each request.
/// @param req_pos   Starting position of each request.
/// @param kv_caches KV cache of each request.
/// @param logits    Output logits, shape: [ntok, dvoc].
__C __export void
forwardBatchDeepSeekOCR(struct DeepSeekOCRModel *,
                        const uint32_t *tokens, uint32_t ntok,
                        const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
                        struct KVCache **kv_caches,
                        void *logits);
| 166 | + |
/// @brief Run inference from precomputed embeddings (for multimodal input,
///        e.g. vision features already projected into the text hidden space).
/// @param inputs_embeds Input embeddings, shape: [ntok, d].
/// @param ntok          Total number of input tokens.
/// @param nreq          Number of requests.
/// @param req_lens      Token count of each request.
/// @param req_pos       Starting position of each request.
/// @param kv_caches     KV cache of each request.
/// @param temperature   Sampling temperature.
/// @param topk          Sampling top-k.
/// @param topp          Sampling top-p.
/// @param output        Output token array, one token per request.
__C __export void
inferBatchDeepSeekOCRWithEmbeds(struct DeepSeekOCRModel *,
                                const void *inputs_embeds, uint32_t ntok,
                                const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
                                struct KVCache **kv_caches,
                                const float *temperature, const uint32_t *topk, const float *topp,
                                uint32_t *output);
| 185 | + |
| 186 | +#endif |
0 commit comments