Skip to content

Commit cf40f94

Browse files
committed
feat: hybrid attn between infinicore and fa2
1 parent 2c563ad commit cf40f94

12 files changed

Lines changed: 357 additions & 6 deletions

csrc/backends/attention_backends.hpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ enum class AttentionBackend {
1313
STATIC_ATTN,
1414
PAGED_ATTN,
1515
FLASH_ATTN,
16+
FLASH_PREFILL,
17+
FLASH_DECODE,
1618
FLASHINFER,
1719
Default = STATIC_ATTN
1820
};
@@ -25,6 +27,10 @@ inline std::ostream &operator<<(std::ostream &os, AttentionBackend backend) {
2527
return os << "AttentionBackend::PAGED_ATTN";
2628
case AttentionBackend::FLASH_ATTN:
2729
return os << "AttentionBackend::FLASH_ATTN";
30+
case AttentionBackend::FLASH_PREFILL:
31+
return os << "AttentionBackend::FLASH_PREFILL";
32+
case AttentionBackend::FLASH_DECODE:
33+
return os << "AttentionBackend::FLASH_DECODE";
2834
case AttentionBackend::FLASHINFER:
2935
return os << "AttentionBackend::FLASHINFER";
3036
default:
@@ -46,12 +52,18 @@ inline AttentionBackend parse_attention_backend(const std::string &backend) {
4652
if (backend == "flash-attn") {
4753
return AttentionBackend::FLASH_ATTN;
4854
}
55+
if (backend == "flash-prefill") {
56+
return AttentionBackend::FLASH_PREFILL;
57+
}
58+
if (backend == "flash-decode") {
59+
return AttentionBackend::FLASH_DECODE;
60+
}
4961
if (backend == "flashinfer") {
5062
return AttentionBackend::FLASHINFER;
5163
}
5264

5365
throw std::invalid_argument(
54-
"Invalid attention_backend: " + backend + ". Valid options are: static-attn, paged-attn, flash-attn, flashinfer");
66+
"Invalid attention_backend: " + backend + ". Valid options are: static-attn, paged-attn, flash-attn, flash-prefill, flash-decode, flashinfer");
5567
}
5668

5769
} // namespace infinilm::backends

csrc/cache/kv_cache.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,9 @@ infinicore::Tensor PagedKVCache::create_layer_kv_cache(
255255
size_t block_size = config.block_size();
256256

257257
infinicore::Shape kv_shape;
258-
if (global_state::get_infinilm_config().attention_backend == backends::AttentionBackend::FLASH_ATTN) {
258+
if (global_state::get_infinilm_config().attention_backend == backends::AttentionBackend::FLASH_ATTN ||
259+
global_state::get_infinilm_config().attention_backend == backends::AttentionBackend::FLASH_PREFILL ||
260+
global_state::get_infinilm_config().attention_backend == backends::AttentionBackend::FLASH_DECODE) {
259261
// FLASH_ATTN, FLASH_PREFILL and FLASH_DECODE all use the BSHD cache layout
260262
kv_shape = {2, num_blocks_per_layer, block_size, num_rank_k_heads, k_dim};
261263
} else {

csrc/layers/attention/backends/attention_layer.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ AttentionLayer::AttentionLayer(size_t num_heads,
2020
case ::infinilm::backends::AttentionBackend::FLASH_ATTN:
2121
attn_backend_impl_ = std::make_shared<backends::FlashAttentionImpl>(num_heads, head_size, scale, num_kv_heads, layer_idx);
2222
break;
23+
case ::infinilm::backends::AttentionBackend::FLASH_PREFILL:
24+
attn_backend_impl_ = std::make_shared<backends::FlashPrefillAttentionImpl>(num_heads, head_size, scale, num_kv_heads, layer_idx);
25+
break;
26+
case ::infinilm::backends::AttentionBackend::FLASH_DECODE:
27+
attn_backend_impl_ = std::make_shared<backends::FlashDecodeAttentionImpl>(num_heads, head_size, scale, num_kv_heads, layer_idx);
28+
break;
2329
default:
2430
throw std::runtime_error("infinilm::layers::attention::AttentionLayer: unsupported attention backend");
2531
}

csrc/layers/attention/backends/attention_layer.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,17 @@
22

33
#include "../../../backends/attention_backends.hpp"
44
#include "../../../global_state/global_state.hpp"
5-
#include "flash_attn.hpp"
65
#include "infinicore/tensor.hpp"
6+
#include "flash_attn.hpp"
7+
#include "flash_decode_attn.hpp"
8+
#include "flash_prefill_attn.hpp"
79
#include "paged_attn.hpp"
810
#include "static_attn.hpp"
911
#include <memory>
1012
#include <variant>
1113

1214
namespace infinilm::layers::attention {
13-
using AttentionImpl = std::variant<std::shared_ptr<backends::StaticAttentionImpl>, std::shared_ptr<backends::PagedAttentionImpl>, std::shared_ptr<backends::FlashAttentionImpl>>;
15+
using AttentionImpl = std::variant<std::shared_ptr<backends::StaticAttentionImpl>, std::shared_ptr<backends::PagedAttentionImpl>, std::shared_ptr<backends::FlashAttentionImpl>, std::shared_ptr<backends::FlashPrefillAttentionImpl>, std::shared_ptr<backends::FlashDecodeAttentionImpl>>;
1416

1517
/**
1618
* @brief Attention layer.
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
#include "flash_decode_attn.hpp"
2+
3+
#include "../../../utils.hpp"
4+
#include "infinicore/ops.hpp"
5+
#include "infinicore/ops/mha_kvcache.hpp"
6+
7+
namespace infinilm::layers::attention::backends {
8+
9+
/**
 * @brief Construct the flash-decode attention backend for one layer.
 *
 * Records the head geometry, scale and layer index, then reads
 * `max_position_embeddings` from the globally registered model config.
 *
 * @param num_heads    Number of query heads.
 * @param head_size    Per-head dimension; also stored as head_dim_.
 * @param scale        Scale factor forwarded to the attention ops.
 * @param num_kv_heads Number of key/value heads.
 * @param layer_idx    Index of the layer this backend belongs to.
 * @throws std::runtime_error if the global config carries no model_config.
 */
FlashDecodeAttentionImpl::FlashDecodeAttentionImpl(size_t num_heads,
                                                   size_t head_size,
                                                   float scale,
                                                   size_t num_kv_heads,
                                                   size_t layer_idx)
    : num_heads_(num_heads),
      head_size_(head_size),
      scale_(scale),
      num_kv_heads_(num_kv_heads),
      layer_idx_(layer_idx),
      head_dim_(head_size) {
    const auto &global_config = infinilm::global_state::get_infinilm_config();
    const auto &model_config = global_config.model_config;

    // Without a model config there is no way to learn the position limit.
    if (!model_config) {
        throw std::runtime_error("infinilm::layers::attention::backends::FlashDecodeAttentionImpl: model_config is null");
    }

    max_position_embeddings_ = model_config->get<size_t>("max_position_embeddings");
}
27+
28+
/**
 * @brief Write the new K/V entries into the paged cache and compute attention.
 *
 * Prefill batches go through infinicore::op::paged_attention_prefill_;
 * decode batches go through infinicore::op::mha_kvcache.
 *
 * @param layer         Owning attention layer (forwarded to the cache update).
 * @param query         Query tensor; dimension 0 is used as seq_len and the
 *                      tensor is viewed as {seq_len, 1, num_heads, head_dim}
 *                      on the decode path.
 * @param key           New key entries to store in the cache.
 * @param value         New value entries to store in the cache.
 * @param kv_cache      Per-layer paged cache; dimension 0 holds K then V.
 * @param attn_metadata Batch metadata. block_tables, slot_mapping and
 *                      total_sequence_lengths are always required;
 *                      input_offsets is required for prefill batches.
 * @return Attention output viewed as {1, seq_len, num_heads * head_dim}.
 */
infinicore::Tensor FlashDecodeAttentionImpl::forward(const AttentionLayer &layer,
                                                     const infinicore::Tensor &query,
                                                     const infinicore::Tensor &key,
                                                     const infinicore::Tensor &value,
                                                     infinicore::Tensor &kv_cache,
                                                     const infinilm::global_state::AttentionMetadata &attn_metadata) const {
    // Bind by reference so the optional metadata entries are not copied on
    // every forward call.
    const auto &total_sequence_lengths = attn_metadata.total_sequence_lengths;
    const auto &input_offsets = attn_metadata.input_offsets;
    const auto &block_tables = attn_metadata.block_tables;
    const auto &slot_mapping = attn_metadata.slot_mapping;

    // Validate every input that will be dereferenced *before* mutating the
    // cache, so a malformed batch cannot leave the cache half-updated.
    ASSERT(block_tables.has_value());
    ASSERT(slot_mapping.has_value());
    ASSERT(total_sequence_lengths.has_value());

    const size_t seq_len = query->shape()[0];
    // A batch with exactly one query row per sequence is treated as decode.
    // NOTE(review): a prefill batch made solely of single-token prompts would
    // also satisfy this and take the decode path — confirm that is intended.
    const bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]);
    if (is_prefill) {
        ASSERT(input_offsets.has_value());
    }

    // 1. Update the paged KV cache.
    auto [k_total, v_total] = do_kv_cache_update(layer, key, value, kv_cache, slot_mapping.value());

    // 2. Compute attention.
    if (is_prefill) {
        // Only the prefill op writes into a caller-provided output buffer.
        infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, query->dtype(), query->device());
        infinicore::op::paged_attention_prefill_(
            attn_output,
            query,
            k_total->permute({0, 2, 1, 3}),
            v_total->permute({0, 2, 1, 3}),
            block_tables.value(),
            total_sequence_lengths.value(),
            input_offsets.value(),
            std::nullopt,
            scale_);
        return attn_output->view({1, seq_len, num_heads_ * head_dim_});
    }

    // Decode: mha_kvcache returns its own output tensor, so no buffer is
    // pre-allocated on this path.
    auto q_for_fa = query->view({seq_len, 1, num_heads_, head_dim_});
    auto attn_out_4d = infinicore::op::mha_kvcache(
        q_for_fa,
        k_total, // [num_blocks, block_size, num_kv_heads, head_dim]
        v_total,
        total_sequence_lengths.value(), // [seq_len] int32 (one entry per sequence)
        block_tables.value(),           // [seq_len, max_num_blocks_per_seq] int32
        std::nullopt,
        scale_);
    return attn_out_4d->view({seq_len, num_heads_, head_dim_})
        ->view({1, seq_len, num_heads_ * head_dim_});
}
77+
78+
std::tuple<infinicore::Tensor, infinicore::Tensor> FlashDecodeAttentionImpl::do_kv_cache_update(const AttentionLayer &layer,
79+
const infinicore::Tensor key,
80+
const infinicore::Tensor value,
81+
infinicore::Tensor &kv_cache,
82+
const infinicore::Tensor slot_mapping) const {
83+
auto k_cache_layer = kv_cache->narrow({{0, 0, 1}})->squeeze(0);
84+
auto v_cache_layer = kv_cache->narrow({{0, 1, 1}})->squeeze(0);
85+
infinicore::op::paged_caching_(
86+
k_cache_layer->permute({0, 2, 1, 3}),
87+
v_cache_layer->permute({0, 2, 1, 3}),
88+
key,
89+
value,
90+
slot_mapping);
91+
92+
return {k_cache_layer, v_cache_layer};
93+
}
94+
} // namespace infinilm::layers::attention::backends
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#pragma once
2+
3+
#include "../../../global_state/global_state.hpp"
4+
#include "infinicore/tensor.hpp"
5+
#include <tuple>
6+
7+
namespace infinilm::layers::attention {
8+
class AttentionLayer;
9+
}
10+
11+
namespace infinilm::layers::attention::backends {
12+
13+
class FlashDecodeAttentionImpl {
14+
public:
15+
FlashDecodeAttentionImpl(size_t num_heads,
16+
size_t head_size,
17+
float scale,
18+
size_t num_kv_heads,
19+
size_t layer_idx);
20+
21+
infinicore::Tensor forward(const AttentionLayer &layer,
22+
const infinicore::Tensor &query,
23+
const infinicore::Tensor &key,
24+
const infinicore::Tensor &value,
25+
infinicore::Tensor &kv_cache,
26+
const infinilm::global_state::AttentionMetadata &attn_metadata) const;
27+
28+
std::tuple<infinicore::Tensor, infinicore::Tensor> do_kv_cache_update(const AttentionLayer &layer,
29+
const infinicore::Tensor key,
30+
const infinicore::Tensor value,
31+
infinicore::Tensor &kv_cache,
32+
const infinicore::Tensor slot_mapping) const;
33+
34+
private:
35+
size_t num_heads_;
36+
size_t head_size_;
37+
float scale_;
38+
size_t num_kv_heads_;
39+
size_t layer_idx_;
40+
size_t head_dim_; // Note: head_dim equals to head_size
41+
size_t max_position_embeddings_;
42+
};
43+
} // namespace infinilm::layers::attention::backends
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
#include "flash_prefill_attn.hpp"
2+
3+
#include "../../../utils.hpp"
4+
#include "infinicore/ops.hpp"
5+
#include "infinicore/ops/mha_varlen.hpp"
6+
7+
namespace infinilm::layers::attention::backends {
8+
9+
/**
 * @brief Construct the flash-prefill attention backend for one layer.
 *
 * Records the head geometry, scale and layer index, then reads
 * `max_position_embeddings` from the globally registered model config.
 *
 * @param num_heads    Number of query heads.
 * @param head_size    Per-head dimension; also stored as head_dim_.
 * @param scale        Scale factor forwarded to the attention ops.
 * @param num_kv_heads Number of key/value heads.
 * @param layer_idx    Index of the layer this backend belongs to.
 * @throws std::runtime_error if the global config carries no model_config.
 */
FlashPrefillAttentionImpl::FlashPrefillAttentionImpl(size_t num_heads,
                                                     size_t head_size,
                                                     float scale,
                                                     size_t num_kv_heads,
                                                     size_t layer_idx)
    : num_heads_(num_heads),
      head_size_(head_size),
      scale_(scale),
      num_kv_heads_(num_kv_heads),
      layer_idx_(layer_idx),
      head_dim_(head_size) {
    const auto &global_config = infinilm::global_state::get_infinilm_config();
    const auto &model_config = global_config.model_config;

    // Without a model config there is no way to learn the position limit.
    if (!model_config) {
        throw std::runtime_error("infinilm::layers::attention::backends::FlashPrefillAttentionImpl: model_config is null");
    }

    max_position_embeddings_ = model_config->get<size_t>("max_position_embeddings");
}
27+
28+
/**
 * @brief Write the new K/V entries into the paged cache and compute attention.
 *
 * Prefill batches go through infinicore::op::mha_varlen_; decode batches go
 * through infinicore::op::paged_attention_.
 *
 * @param layer         Owning attention layer (forwarded to the cache update).
 * @param query         Query tensor; dimension 0 is used as seq_len.
 * @param key           New key entries to store in the cache.
 * @param value         New value entries to store in the cache.
 * @param kv_cache      Per-layer paged cache; dimension 0 holds K then V.
 * @param attn_metadata Batch metadata. block_tables, slot_mapping and
 *                      total_sequence_lengths are always required;
 *                      input_offsets and cu_seqlens are required for prefill.
 * @return Attention output viewed as {1, seq_len, num_heads * head_dim}.
 */
infinicore::Tensor FlashPrefillAttentionImpl::forward(const AttentionLayer &layer,
                                                      const infinicore::Tensor &query,
                                                      const infinicore::Tensor &key,
                                                      const infinicore::Tensor &value,
                                                      infinicore::Tensor &kv_cache,
                                                      const infinilm::global_state::AttentionMetadata &attn_metadata) const {
    // Bind by reference so the optional metadata entries are not copied on
    // every forward call.
    const auto &total_sequence_lengths = attn_metadata.total_sequence_lengths;
    const auto &input_offsets = attn_metadata.input_offsets;
    const auto &block_tables = attn_metadata.block_tables;
    const auto &slot_mapping = attn_metadata.slot_mapping;
    const auto &cu_seqlens = attn_metadata.cu_seqlens;

    // Validate every input that will be dereferenced *before* mutating the
    // cache, so a malformed batch cannot leave the cache half-updated.
    ASSERT(block_tables.has_value());
    ASSERT(slot_mapping.has_value());
    ASSERT(total_sequence_lengths.has_value());

    const size_t seq_len = query->shape()[0];
    // A batch with exactly one query row per sequence is treated as decode.
    // NOTE(review): a prefill batch made solely of single-token prompts would
    // also satisfy this and take the decode path — confirm that is intended.
    const bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]);
    if (is_prefill) {
        ASSERT(input_offsets.has_value());
        ASSERT(cu_seqlens.has_value());
    }

    // 1. Update the paged KV cache.
    auto [k_total, v_total] = do_kv_cache_update(layer, key, value, kv_cache, slot_mapping.value());

    // 2. Compute attention; both ops write into attn_output in place.
    infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, query->dtype(), query->device());
    if (is_prefill) {
        infinicore::op::mha_varlen_(
            attn_output,
            query,
            k_total,
            v_total,
            input_offsets.value(),
            cu_seqlens.value(),
            block_tables.value(),
            // presumably max_seqlen_q / max_seqlen_k — TODO confirm against the op signature
            max_position_embeddings_,
            max_position_embeddings_,
            std::nullopt,
            scale_);
    } else {
        infinicore::op::paged_attention_(
            attn_output,
            query,
            k_total->permute({0, 2, 1, 3}),
            v_total->permute({0, 2, 1, 3}),
            block_tables.value(),
            total_sequence_lengths.value(),
            std::nullopt,
            scale_);
    }
    return attn_output->view({1, seq_len, num_heads_ * head_dim_});
}
78+
79+
std::tuple<infinicore::Tensor, infinicore::Tensor> FlashPrefillAttentionImpl::do_kv_cache_update(const AttentionLayer &layer,
80+
const infinicore::Tensor key,
81+
const infinicore::Tensor value,
82+
infinicore::Tensor &kv_cache,
83+
const infinicore::Tensor slot_mapping) const {
84+
auto k_cache_layer = kv_cache->narrow({{0, 0, 1}})->squeeze(0);
85+
auto v_cache_layer = kv_cache->narrow({{0, 1, 1}})->squeeze(0);
86+
infinicore::op::paged_caching_(
87+
k_cache_layer->permute({0, 2, 1, 3}),
88+
v_cache_layer->permute({0, 2, 1, 3}),
89+
key,
90+
value,
91+
slot_mapping);
92+
93+
return {k_cache_layer, v_cache_layer};
94+
}
95+
} // namespace infinilm::layers::attention::backends
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#pragma once
2+
3+
#include "../../../global_state/global_state.hpp"
4+
#include "infinicore/tensor.hpp"
5+
#include <tuple>
6+
7+
namespace infinilm::layers::attention {
8+
class AttentionLayer;
9+
}
10+
11+
namespace infinilm::layers::attention::backends {
12+
13+
class FlashPrefillAttentionImpl {
14+
public:
15+
FlashPrefillAttentionImpl(size_t num_heads,
16+
size_t head_size,
17+
float scale,
18+
size_t num_kv_heads,
19+
size_t layer_idx);
20+
21+
infinicore::Tensor forward(const AttentionLayer &layer,
22+
const infinicore::Tensor &query,
23+
const infinicore::Tensor &key,
24+
const infinicore::Tensor &value,
25+
infinicore::Tensor &kv_cache,
26+
const infinilm::global_state::AttentionMetadata &attn_metadata) const;
27+
28+
std::tuple<infinicore::Tensor, infinicore::Tensor> do_kv_cache_update(const AttentionLayer &layer,
29+
const infinicore::Tensor key,
30+
const infinicore::Tensor value,
31+
infinicore::Tensor &kv_cache,
32+
const infinicore::Tensor slot_mapping) const;
33+
34+
private:
35+
size_t num_heads_;
36+
size_t head_size_;
37+
float scale_;
38+
size_t num_kv_heads_;
39+
size_t layer_idx_;
40+
size_t head_dim_; // Note: head_dim equals to head_size
41+
size_t max_position_embeddings_;
42+
};
43+
} // namespace infinilm::layers::attention::backends

csrc/models/infinilm_model.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,27 @@ std::vector<infinicore::Tensor> InfinilmModel::default_allocate_kv_cache_tensors
6262
case backends::AttentionBackend::FLASH_ATTN: {
6363
;
6464
}
65+
case backends::AttentionBackend::FLASH_PREFILL:
66+
case backends::AttentionBackend::FLASH_DECODE: {
67+
auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config);
68+
if (nullptr == paged_kv_cache_config) {
69+
throw std::runtime_error(
70+
"infinilm::InfinilmModel::default_allocate_kv_cache_tensors: invalid paged kv cache config type");
71+
}
72+
kv_cache_vec.reserve(num_hidden_layers);
73+
74+
for (size_t layer_idx = 0; layer_idx < num_hidden_layers; ++layer_idx) {
75+
auto kv_cache = cache::PagedKVCache::create_layer_kv_cache(
76+
head_dim,
77+
head_dim,
78+
num_key_value_heads,
79+
num_key_value_heads,
80+
dtype,
81+
*paged_kv_cache_config);
82+
kv_cache_vec.push_back(kv_cache);
83+
}
84+
break;
85+
}
6586
case backends::AttentionBackend::PAGED_ATTN: {
6687
auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config);
6788
if (nullptr == paged_kv_cache_config) {

0 commit comments

Comments
 (0)