|
1 | 1 | #include "siglip_vision.hpp" |
2 | 2 |
|
| 3 | +#include "../../global_state/global_state.hpp" |
3 | 4 | #include "infinicore/ops.hpp" |
| 5 | +#include "infinicore/ops/mha.hpp" |
4 | 6 |
|
5 | 7 | #include <cmath> |
6 | 8 | #include <cstring> |
@@ -92,44 +94,52 @@ SiglipAttention::SiglipAttention(const nlohmann::json &config, |
92 | 94 | if (embed_dim_ % num_heads_ != 0) { |
93 | 95 | throw std::runtime_error("SiglipAttention: embed_dim must be divisible by num_heads"); |
94 | 96 | } |
95 | | - INFINICORE_NN_MODULE_INIT(q_proj, embed_dim_, embed_dim_, true, dtype, device); |
96 | | - INFINICORE_NN_MODULE_INIT(k_proj, embed_dim_, embed_dim_, true, dtype, device); |
97 | | - INFINICORE_NN_MODULE_INIT(v_proj, embed_dim_, embed_dim_, true, dtype, device); |
| 97 | + qkv_proj_ = std::make_shared<infinilm::layers::linear::QKVParallelLinear>( |
| 98 | + embed_dim_, head_dim_, num_heads_, num_heads_, |
| 99 | + "q_proj", "k_proj", "v_proj", [this](const std::string &n, infinicore::nn::Parameter p) { this->register_parameter(n, std::move(p)); }, |
| 100 | + nullptr, true, dtype, device); |
| 101 | + |
98 | 102 | INFINICORE_NN_MODULE_INIT(out_proj, embed_dim_, embed_dim_, true, dtype, device); |
| 103 | + |
| 104 | + attention_backend_ = infinilm::global_state::get_infinilm_config().attention_backend; |
99 | 105 | } |
100 | 106 |
|
101 | 107 | infinicore::Tensor SiglipAttention::forward(const infinicore::Tensor &hidden_states, |
102 | 108 | const std::optional<infinicore::Tensor> &attention_mask) const { |
103 | | - (void)attention_mask; |
104 | 109 | auto shape = hidden_states->shape(); |
105 | 110 | size_t batch_size = shape[0]; |
106 | 111 | size_t seq_len = shape[1]; |
107 | 112 |
|
108 | | - auto q = q_proj_->forward(const_cast<infinicore::Tensor &>(hidden_states)); |
109 | | - auto k = k_proj_->forward(const_cast<infinicore::Tensor &>(hidden_states)); |
110 | | - auto v = v_proj_->forward(const_cast<infinicore::Tensor &>(hidden_states)); |
111 | | - |
112 | | - auto q_reshaped = q->view({batch_size, seq_len, num_heads_, head_dim_})->permute({0, 2, 1, 3})->contiguous(); |
113 | | - auto k_reshaped = k->view({batch_size, seq_len, num_heads_, head_dim_})->permute({0, 2, 1, 3})->contiguous(); |
114 | | - auto v_reshaped = v->view({batch_size, seq_len, num_heads_, head_dim_})->permute({0, 2, 1, 3})->contiguous(); |
115 | | - |
116 | | - auto q_flat = q_reshaped->view({batch_size * num_heads_, seq_len, head_dim_}); |
117 | | - auto k_flat = k_reshaped->view({batch_size * num_heads_, seq_len, head_dim_}); |
118 | | - auto v_flat = v_reshaped->view({batch_size * num_heads_, seq_len, head_dim_}); |
119 | | - |
120 | | - auto k_t = k_flat->permute({0, 2, 1}); |
121 | | - auto attn_weights = infinicore::op::matmul(q_flat, k_t, scale_); |
| 113 | + auto qkv = qkv_proj_->forward(const_cast<infinicore::Tensor &>(hidden_states))->view({batch_size, seq_len, num_heads_ * 3, head_dim_}); |
| 114 | + auto q = qkv->narrow({{2, 0, num_heads_}}); |
| 115 | + auto k = qkv->narrow({{2, num_heads_, num_heads_}}); |
| 116 | + auto v = qkv->narrow({{2, num_heads_ * 2, num_heads_}}); |
122 | 117 |
|
123 | | - auto attn_view = attn_weights->view({batch_size * num_heads_, seq_len, seq_len}); |
124 | | - infinicore::op::softmax_(attn_view, attn_view, -1); |
125 | | - |
126 | | - auto attn_output = infinicore::op::matmul(attn_weights, v_flat); |
127 | | - auto out = attn_output->view({batch_size, num_heads_, seq_len, head_dim_}) |
128 | | - ->permute({0, 2, 1, 3}) |
129 | | - ->contiguous() |
130 | | - ->view({batch_size, seq_len, embed_dim_}); |
131 | | - |
132 | | - return out_proj_->forward(out); |
| 118 | + if (attention_backend_ == infinilm::backends::AttentionBackend::FLASH_ATTN) { |
| 119 | + auto out = infinicore::op::mha(q, k, v, std::nullopt, scale_, false)->view({batch_size, seq_len, num_heads_ * head_dim_}); |
| 120 | + return out_proj_->forward(out); |
| 121 | + } else { |
| 122 | + auto q_reshaped = q->view({batch_size, seq_len, num_heads_, head_dim_})->permute({0, 2, 1, 3})->contiguous(); |
| 123 | + auto k_reshaped = k->view({batch_size, seq_len, num_heads_, head_dim_})->permute({0, 2, 1, 3})->contiguous(); |
| 124 | + auto v_reshaped = v->view({batch_size, seq_len, num_heads_, head_dim_})->permute({0, 2, 1, 3})->contiguous(); |
| 125 | + |
| 126 | + auto q_flat = q_reshaped->view({batch_size * num_heads_, seq_len, head_dim_}); |
| 127 | + auto k_flat = k_reshaped->view({batch_size * num_heads_, seq_len, head_dim_}); |
| 128 | + auto v_flat = v_reshaped->view({batch_size * num_heads_, seq_len, head_dim_}); |
| 129 | + |
| 130 | + auto k_t = k_flat->permute({0, 2, 1}); |
| 131 | + auto attn_weights = infinicore::op::matmul(q_flat, k_t, scale_); |
| 132 | + |
| 133 | + auto attn_view = attn_weights->view({batch_size * num_heads_, seq_len, seq_len}); |
| 134 | + infinicore::op::softmax_(attn_view, attn_view, -1); |
| 135 | + |
| 136 | + auto attn_output = infinicore::op::matmul(attn_weights, v_flat); |
| 137 | + auto out = attn_output->view({batch_size, num_heads_, seq_len, head_dim_}) |
| 138 | + ->permute({0, 2, 1, 3}) |
| 139 | + ->contiguous() |
| 140 | + ->view({batch_size, seq_len, embed_dim_}); |
| 141 | + return out_proj_->forward(out); |
| 142 | + } |
133 | 143 | } |
134 | 144 |
|
135 | 145 | SiglipMLP::SiglipMLP(const nlohmann::json &config, |
|
0 commit comments