Skip to content

Commit cd3bb2b

Browse files
issue/343 optimize siglip attention
1 parent d493491 commit cd3bb2b

8 files changed

Lines changed: 82 additions & 58 deletions

File tree

csrc/layers/linear/fused_linear.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
3131
const infinicore::Device &device,
3232
engine::distributed::RankInfo rank_info)
3333
: infinilm::nn::ColumnParallelLinear(
34-
hidden_size,
35-
calculate_out_feature_size(num_q_head, q_dim, num_k_head, k_dim, num_v_head, v_dim, rank_info),
36-
quantization,
37-
(q_bias || k_bias || v_bias),
38-
dtype,
39-
device,
40-
rank_info.tp_rank,
41-
rank_info.tp_size),
34+
hidden_size,
35+
calculate_out_feature_size(num_q_head, q_dim, num_k_head, k_dim, num_v_head, v_dim, rank_info),
36+
quantization == nullptr ? std::make_shared<infinilm::quantization::NoneQuantization>() : quantization,
37+
(q_bias || k_bias || v_bias),
38+
dtype,
39+
device,
40+
rank_info.tp_rank,
41+
rank_info.tp_size),
4242
q_dim_(q_dim),
4343
k_dim_(k_dim),
4444
v_dim_(v_dim),
@@ -120,7 +120,17 @@ GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermedia
120120
std::shared_ptr<infinilm::quantization::BaseQuantization> quantization,
121121
const infinicore::DataType &dtype, const infinicore::Device &device,
122122
engine::distributed::RankInfo rank_info)
123-
: infinilm::nn::ColumnParallelLinear(hidden_size, intermediate_size * 2, quantization, gate_bias || up_bias, dtype, device, rank_info.tp_rank, rank_info.tp_size), gate_bias_(gate_bias), up_bias_(up_bias) {
123+
: infinilm::nn::ColumnParallelLinear(
124+
hidden_size,
125+
intermediate_size * 2,
126+
quantization == nullptr ? std::make_shared<infinilm::quantization::NoneQuantization>() : quantization,
127+
gate_bias || up_bias,
128+
dtype,
129+
device,
130+
rank_info.tp_rank,
131+
rank_info.tp_size),
132+
gate_bias_(gate_bias),
133+
up_bias_(up_bias) {
124134
if (gate_bias_ != up_bias_) {
125135
throw std::runtime_error("Not supported yet: gate_bias and up_bias should be given at the same time");
126136
}

csrc/layers/linear/fused_linear.hpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#pragma once
22
#include "../../engine/distributed/communication_group.hpp"
3-
#include "linear.hpp"
43
#include "../quantization/quantization.hpp"
4+
#include "linear.hpp"
55
#include <functional>
66

77
namespace infinilm::layers::linear {
@@ -13,15 +13,15 @@ class QKVParallelLinear : public infinilm::nn::ColumnParallelLinear {
1313
size_t q_dim, size_t k_dim, size_t v_dim,
1414
size_t num_q_head, size_t num_k_head, size_t num_v_head,
1515
bool q_bias, bool k_bias, bool v_bias,
16-
std::shared_ptr<infinilm::quantization::BaseQuantization> quantization,
16+
std::shared_ptr<infinilm::quantization::BaseQuantization> quantization = nullptr,
1717
const infinicore::DataType &dtype = infinicore::DataType::F32,
1818
const infinicore::Device &device = infinicore::Device(),
1919
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
2020

2121
explicit QKVParallelLinear(size_t hidden_size,
2222
size_t head_dim,
2323
size_t num_q_head, size_t num_kv_head,
24-
std::shared_ptr<infinilm::quantization::BaseQuantization> quantization,
24+
std::shared_ptr<infinilm::quantization::BaseQuantization> quantization = nullptr,
2525
bool bias = false,
2626
const infinicore::DataType &dtype = infinicore::DataType::F32,
2727
const infinicore::Device &device = infinicore::Device(),
@@ -32,7 +32,7 @@ class QKVParallelLinear : public infinilm::nn::ColumnParallelLinear {
3232
size_t num_q_head, size_t num_kv_head,
3333
const std::string &q_name, const std::string &k_name, const std::string &v_name,
3434
RegisterParamFn register_fn,
35-
std::shared_ptr<infinilm::quantization::BaseQuantization> quantization,
35+
std::shared_ptr<infinilm::quantization::BaseQuantization> quantization = nullptr,
3636
bool bias = false,
3737
const infinicore::DataType &dtype = infinicore::DataType::F32,
3838
const infinicore::Device &device = infinicore::Device(),
@@ -84,21 +84,22 @@ class QKVParallelLinear : public infinilm::nn::ColumnParallelLinear {
8484

8585
class GateUpParallelLinear : public infinilm::nn::ColumnParallelLinear {
8686
public:
87-
GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, std::shared_ptr<infinilm::quantization::BaseQuantization> quantization,
87+
GateUpParallelLinear(size_t hidden_size, size_t intermediate_size,
88+
std::shared_ptr<infinilm::quantization::BaseQuantization> quantization = nullptr,
8889
bool bias = false,
8990
const infinicore::DataType &dtype = infinicore::DataType::F32,
9091
const infinicore::Device &device = infinicore::Device(),
9192
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
9293

9394
GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool gate_bias, bool up_bias,
94-
std::shared_ptr<infinilm::quantization::BaseQuantization> quantization,
95+
std::shared_ptr<infinilm::quantization::BaseQuantization> quantization = nullptr,
9596
const infinicore::DataType &dtype = infinicore::DataType::F32, const infinicore::Device &device = infinicore::Device(),
9697
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
9798

9899
GateUpParallelLinear(size_t hidden_size, size_t intermediate_size,
99100
const std::string &gate_name, const std::string &up_name,
100101
RegisterParamFn register_fn,
101-
std::shared_ptr<infinilm::quantization::BaseQuantization> quantization,
102+
std::shared_ptr<infinilm::quantization::BaseQuantization> quantization = nullptr,
102103
bool bias = false,
103104
const infinicore::DataType &dtype = infinicore::DataType::F32,
104105
const infinicore::Device &device = infinicore::Device(),

csrc/layers/quantization/none_quantization.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
namespace infinilm::quantization {
66

7+
NoneQuantization::NoneQuantization() : NoneQuantization(nlohmann::json()) {}
8+
79
std::vector<ParamDescriptor> NoneQuantization::get_param_layout(
810
size_t in_features, size_t out_features,
911
int split_dim, int tp_rank, int tp_size,
@@ -14,8 +16,7 @@ std::vector<ParamDescriptor> NoneQuantization::get_param_layout(
1416
std::vector<ParamDescriptor> descs;
1517
descs.push_back({"weight", {out_features, in_features}, dtype, split_dim, tp_rank, tp_size});
1618
if (bias) {
17-
descs.push_back({"bias", {out_features}, dtype, split_dim >= 0 ? 0 : -1,
18-
split_dim >= 0 ? tp_rank : 0, split_dim >= 0 ? tp_size : 1});
19+
descs.push_back({"bias", {out_features}, dtype, split_dim >= 0 ? 0 : -1, split_dim >= 0 ? tp_rank : 0, split_dim >= 0 ? tp_size : 1});
1920
}
2021
return descs;
2122
}

csrc/layers/quantization/none_quantization.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ namespace infinilm::quantization {
66
class NoneQuantization : public BaseQuantization {
77
public:
88
explicit NoneQuantization(const nlohmann::json &quant_config)
9-
: BaseQuantization(quant_config) {};
9+
: BaseQuantization(quant_config){};
10+
11+
NoneQuantization();
1012

1113
QuantScheme get_quant_scheme() const override {
1214
return QuantScheme::NONE;

csrc/models/minicpmv/resampler.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#pragma once
22

33
#include "../../config/model_config.hpp"
4+
#include "../../layers/linear/fused_linear.hpp"
45
#include "infinicore/nn/layer_norm.hpp"
5-
#include "infinicore/nn/linear.hpp"
66
#include "infinicore/nn/module.hpp"
77
#include "infinicore/tensor.hpp"
88

@@ -30,7 +30,7 @@ class ResamplerAttention : public infinicore::nn::Module {
3030

3131
INFINICORE_NN_PARAMETER(in_proj_weight);
3232
INFINICORE_NN_PARAMETER(in_proj_bias);
33-
INFINICORE_NN_MODULE(infinicore::nn::Linear, out_proj);
33+
INFINICORE_NN_MODULE(infinilm::nn::Linear, out_proj);
3434
};
3535

3636
class Resampler : public infinicore::nn::Module {
@@ -59,7 +59,7 @@ class Resampler : public infinicore::nn::Module {
5959
INFINICORE_NN_PARAMETER(query);
6060
INFINICORE_NN_PARAMETER(proj);
6161
INFINICORE_NN_BUFFER(embedding_table);
62-
INFINICORE_NN_MODULE(infinicore::nn::Linear, kv_proj);
62+
INFINICORE_NN_MODULE(infinilm::nn::Linear, kv_proj);
6363
INFINICORE_NN_MODULE(ResamplerAttention, attn);
6464
INFINICORE_NN_MODULE(infinicore::nn::LayerNorm, ln_q);
6565
INFINICORE_NN_MODULE(infinicore::nn::LayerNorm, ln_kv);

csrc/models/minicpmv/siglip_vision.cpp

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include "siglip_vision.hpp"
22

3+
#include "../../global_state/global_state.hpp"
34
#include "infinicore/ops.hpp"
5+
#include "infinicore/ops/mha.hpp"
46

57
#include <cmath>
68
#include <cstring>
@@ -92,44 +94,52 @@ SiglipAttention::SiglipAttention(const nlohmann::json &config,
9294
if (embed_dim_ % num_heads_ != 0) {
9395
throw std::runtime_error("SiglipAttention: embed_dim must be divisible by num_heads");
9496
}
95-
INFINICORE_NN_MODULE_INIT(q_proj, embed_dim_, embed_dim_, true, dtype, device);
96-
INFINICORE_NN_MODULE_INIT(k_proj, embed_dim_, embed_dim_, true, dtype, device);
97-
INFINICORE_NN_MODULE_INIT(v_proj, embed_dim_, embed_dim_, true, dtype, device);
97+
qkv_proj_ = std::make_shared<infinilm::layers::linear::QKVParallelLinear>(
98+
embed_dim_, head_dim_, num_heads_, num_heads_,
99+
"q_proj", "k_proj", "v_proj", [this](const std::string &n, infinicore::nn::Parameter p) { this->register_parameter(n, std::move(p)); },
100+
nullptr, true, dtype, device);
101+
98102
INFINICORE_NN_MODULE_INIT(out_proj, embed_dim_, embed_dim_, true, dtype, device);
103+
104+
attention_backend_ = infinilm::global_state::get_infinilm_config().attention_backend;
99105
}
100106

101107
infinicore::Tensor SiglipAttention::forward(const infinicore::Tensor &hidden_states,
102108
const std::optional<infinicore::Tensor> &attention_mask) const {
103-
(void)attention_mask;
104109
auto shape = hidden_states->shape();
105110
size_t batch_size = shape[0];
106111
size_t seq_len = shape[1];
107112

108-
auto q = q_proj_->forward(const_cast<infinicore::Tensor &>(hidden_states));
109-
auto k = k_proj_->forward(const_cast<infinicore::Tensor &>(hidden_states));
110-
auto v = v_proj_->forward(const_cast<infinicore::Tensor &>(hidden_states));
111-
112-
auto q_reshaped = q->view({batch_size, seq_len, num_heads_, head_dim_})->permute({0, 2, 1, 3})->contiguous();
113-
auto k_reshaped = k->view({batch_size, seq_len, num_heads_, head_dim_})->permute({0, 2, 1, 3})->contiguous();
114-
auto v_reshaped = v->view({batch_size, seq_len, num_heads_, head_dim_})->permute({0, 2, 1, 3})->contiguous();
115-
116-
auto q_flat = q_reshaped->view({batch_size * num_heads_, seq_len, head_dim_});
117-
auto k_flat = k_reshaped->view({batch_size * num_heads_, seq_len, head_dim_});
118-
auto v_flat = v_reshaped->view({batch_size * num_heads_, seq_len, head_dim_});
119-
120-
auto k_t = k_flat->permute({0, 2, 1});
121-
auto attn_weights = infinicore::op::matmul(q_flat, k_t, scale_);
113+
auto qkv = qkv_proj_->forward(const_cast<infinicore::Tensor &>(hidden_states))->view({batch_size, seq_len, num_heads_ * 3, head_dim_});
114+
auto q = qkv->narrow({{2, 0, num_heads_}});
115+
auto k = qkv->narrow({{2, num_heads_, num_heads_}});
116+
auto v = qkv->narrow({{2, num_heads_ * 2, num_heads_}});
122117

123-
auto attn_view = attn_weights->view({batch_size * num_heads_, seq_len, seq_len});
124-
infinicore::op::softmax_(attn_view, attn_view, -1);
125-
126-
auto attn_output = infinicore::op::matmul(attn_weights, v_flat);
127-
auto out = attn_output->view({batch_size, num_heads_, seq_len, head_dim_})
128-
->permute({0, 2, 1, 3})
129-
->contiguous()
130-
->view({batch_size, seq_len, embed_dim_});
131-
132-
return out_proj_->forward(out);
118+
if (attention_backend_ == infinilm::backends::AttentionBackend::FLASH_ATTN) {
119+
auto out = infinicore::op::mha(q, k, v, std::nullopt, scale_, false)->view({batch_size, seq_len, num_heads_ * head_dim_});
120+
return out_proj_->forward(out);
121+
} else {
122+
auto q_reshaped = q->view({batch_size, seq_len, num_heads_, head_dim_})->permute({0, 2, 1, 3})->contiguous();
123+
auto k_reshaped = k->view({batch_size, seq_len, num_heads_, head_dim_})->permute({0, 2, 1, 3})->contiguous();
124+
auto v_reshaped = v->view({batch_size, seq_len, num_heads_, head_dim_})->permute({0, 2, 1, 3})->contiguous();
125+
126+
auto q_flat = q_reshaped->view({batch_size * num_heads_, seq_len, head_dim_});
127+
auto k_flat = k_reshaped->view({batch_size * num_heads_, seq_len, head_dim_});
128+
auto v_flat = v_reshaped->view({batch_size * num_heads_, seq_len, head_dim_});
129+
130+
auto k_t = k_flat->permute({0, 2, 1});
131+
auto attn_weights = infinicore::op::matmul(q_flat, k_t, scale_);
132+
133+
auto attn_view = attn_weights->view({batch_size * num_heads_, seq_len, seq_len});
134+
infinicore::op::softmax_(attn_view, attn_view, -1);
135+
136+
auto attn_output = infinicore::op::matmul(attn_weights, v_flat);
137+
auto out = attn_output->view({batch_size, num_heads_, seq_len, head_dim_})
138+
->permute({0, 2, 1, 3})
139+
->contiguous()
140+
->view({batch_size, seq_len, embed_dim_});
141+
return out_proj_->forward(out);
142+
}
133143
}
134144

135145
SiglipMLP::SiglipMLP(const nlohmann::json &config,

csrc/models/minicpmv/siglip_vision.hpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
#pragma once
22

3+
#include "../../backends/attention_backends.hpp"
34
#include "../../config/model_config.hpp"
5+
#include "../../layers/linear/fused_linear.hpp"
46
#include "infinicore/nn/embedding.hpp"
57
#include "infinicore/nn/layer_norm.hpp"
6-
#include "infinicore/nn/linear.hpp"
78
#include "infinicore/nn/module.hpp"
89
#include "infinicore/tensor.hpp"
910
#include <nlohmann/json.hpp>
@@ -61,11 +62,10 @@ class SiglipAttention : public infinicore::nn::Module {
6162
size_t num_heads_;
6263
size_t head_dim_;
6364
float scale_;
65+
infinilm::backends::AttentionBackend attention_backend_;
6466

65-
INFINICORE_NN_MODULE(infinicore::nn::Linear, q_proj);
66-
INFINICORE_NN_MODULE(infinicore::nn::Linear, k_proj);
67-
INFINICORE_NN_MODULE(infinicore::nn::Linear, v_proj);
68-
INFINICORE_NN_MODULE(infinicore::nn::Linear, out_proj);
67+
INFINICORE_NN_MODULE(infinilm::layers::linear::QKVParallelLinear, qkv_proj);
68+
INFINICORE_NN_MODULE(infinilm::nn::Linear, out_proj);
6969
};
7070

7171
class SiglipMLP : public infinicore::nn::Module {
@@ -78,8 +78,8 @@ class SiglipMLP : public infinicore::nn::Module {
7878

7979
private:
8080
std::string activation_;
81-
INFINICORE_NN_MODULE(infinicore::nn::Linear, fc1);
82-
INFINICORE_NN_MODULE(infinicore::nn::Linear, fc2);
81+
INFINICORE_NN_MODULE(infinilm::nn::Linear, fc1);
82+
INFINICORE_NN_MODULE(infinilm::nn::Linear, fc2);
8383
};
8484

8585
class SiglipEncoderLayer : public infinicore::nn::Module {

test/service/request.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def main():
134134
)
135135

136136
client = AsyncOpenAI(base_url=api_url, api_key="default")
137-
asyncio.run(benchmark_user(client, messages, args.model_name))
137+
asyncio.run(benchmark_user(client, messages, args.model))
138138

139139

140140
if __name__ == "__main__":

0 commit comments

Comments
 (0)