Skip to content

Commit 96a760a

Browse files
committed
draft commit
1 parent bc0b7e5 commit 96a760a

14 files changed

Lines changed: 320 additions & 13 deletions

File tree

csrc/backends/attention_backends.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ enum class AttentionBackend {
1313
STATIC_ATTN,
1414
PAGED_ATTN,
1515
FLASH_ATTN,
16+
HYBRID_ATTN,
1617
FLASHINFER,
1718
Default = STATIC_ATTN
1819
};
@@ -25,6 +26,8 @@ inline std::ostream &operator<<(std::ostream &os, AttentionBackend backend) {
2526
return os << "AttentionBackend::PAGED_ATTN";
2627
case AttentionBackend::FLASH_ATTN:
2728
return os << "AttentionBackend::FLASH_ATTN";
29+
case AttentionBackend::HYBRID_ATTN:
30+
return os << "AttentionBackend::HYBRID_ATTN";
2831
case AttentionBackend::FLASHINFER:
2932
return os << "AttentionBackend::FLASHINFER";
3033
default:
@@ -46,12 +49,15 @@ inline AttentionBackend parse_attention_backend(const std::string &backend) {
4649
if (backend == "flash-attn") {
4750
return AttentionBackend::FLASH_ATTN;
4851
}
52+
if (backend == "hybrid-attn") {
53+
return AttentionBackend::HYBRID_ATTN;
54+
}
4955
if (backend == "flashinfer") {
5056
return AttentionBackend::FLASHINFER;
5157
}
5258

5359
throw std::invalid_argument(
54-
"Invalid attention_backend: " + backend + ". Valid options are: static-attn, paged-attn, flash-attn, flashinfer");
60+
"Invalid attention_backend: " + backend + ". Valid options are: static-attn, paged-attn, flash-attn, hybrid-attn, flashinfer");
5561
}
5662

5763
} // namespace infinilm::backends

csrc/layers/attention/backends/attention_layer.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ AttentionLayer::AttentionLayer(size_t num_heads,
2020
case ::infinilm::backends::AttentionBackend::FLASH_ATTN:
2121
attn_backend_impl_ = std::make_shared<backends::FlashAttentionImpl>(num_heads, head_size, scale, num_kv_heads, layer_idx);
2222
break;
23+
case ::infinilm::backends::AttentionBackend::HYBRID_ATTN:
24+
attn_backend_impl_ = std::make_shared<backends::HybridAttentionImpl>(num_heads, head_size, scale, num_kv_heads, layer_idx);
25+
break;
2326
default:
2427
throw std::runtime_error("infinilm::layers::attention::AttentionLayer: unsupported attention backend");
2528
}

csrc/layers/attention/backends/attention_layer.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
#include "../../../backends/attention_backends.hpp"
44
#include "../../../global_state/global_state.hpp"
55
#include "flash_attn.hpp"
6+
#include "hybrid_attn.hpp"
67
#include "infinicore/tensor.hpp"
78
#include "paged_attn.hpp"
89
#include "static_attn.hpp"
910
#include <memory>
1011
#include <variant>
1112

1213
namespace infinilm::layers::attention {
13-
using AttentionImpl = std::variant<std::shared_ptr<backends::StaticAttentionImpl>, std::shared_ptr<backends::PagedAttentionImpl>, std::shared_ptr<backends::FlashAttentionImpl>>;
14+
using AttentionImpl = std::variant<std::shared_ptr<backends::StaticAttentionImpl>, std::shared_ptr<backends::PagedAttentionImpl>, std::shared_ptr<backends::FlashAttentionImpl>, std::shared_ptr<backends::HybridAttentionImpl>>;
1415

1516
/**
1617
* @brief Attention layer.
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
#include "hybrid_attn.hpp"
2+
3+
#include "../../../utils.hpp"
4+
#include "infinicore/ops/mha_varlen.hpp"
5+
#include "infinicore/ops/paged_attention.hpp"
6+
#include "infinicore/ops/paged_caching.hpp"
7+
8+
#include <stdexcept>
9+
10+
namespace infinilm::layers::attention::backends {
11+
12+
HybridAttentionImpl::HybridAttentionImpl(size_t num_heads,
13+
size_t head_size,
14+
float scale,
15+
size_t num_kv_heads,
16+
size_t layer_idx)
17+
: num_heads_(num_heads),
18+
scale_(scale),
19+
head_dim_(head_size) {
20+
21+
const infinilm::global_state::InfinilmConfig &infinilm_config = infinilm::global_state::get_infinilm_config();
22+
if (!infinilm_config.model_config) {
23+
throw std::runtime_error("infinilm::layers::attention::backends::HybridAttentionImpl: model_config is null");
24+
}
25+
max_position_embeddings_ = infinilm_config.model_config->get<size_t>("max_position_embeddings");
26+
}
27+
28+
infinicore::Tensor HybridAttentionImpl::forward(const AttentionLayer &layer,
29+
const infinicore::Tensor &query,
30+
const infinicore::Tensor &key,
31+
const infinicore::Tensor &value,
32+
infinicore::Tensor &kv_cache,
33+
const infinilm::global_state::AttentionMetadata &attn_metadata) const {
34+
auto total_sequence_lengths = attn_metadata.total_sequence_lengths;
35+
auto input_offsets = attn_metadata.input_offsets;
36+
auto block_tables = attn_metadata.block_tables;
37+
auto slot_mapping = attn_metadata.slot_mapping;
38+
auto cu_seqlens = attn_metadata.cu_seqlens;
39+
40+
ASSERT(total_sequence_lengths.has_value());
41+
ASSERT(input_offsets.has_value());
42+
ASSERT(block_tables.has_value());
43+
ASSERT(slot_mapping.has_value());
44+
ASSERT(cu_seqlens.has_value());
45+
46+
auto [k_total, v_total] = do_kv_cache_update(key, value, kv_cache, slot_mapping.value());
47+
48+
size_t seq_len = query->shape()[0];
49+
bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]);
50+
51+
infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, query->dtype(), query->device());
52+
if (is_prefill) {
53+
infinicore::op::mha_varlen_(
54+
attn_output,
55+
query,
56+
key,
57+
value,
58+
input_offsets.value(),
59+
cu_seqlens.value(),
60+
block_tables.value(),
61+
max_position_embeddings_,
62+
max_position_embeddings_,
63+
std::nullopt,
64+
scale_);
65+
} else {
66+
infinicore::op::paged_attention_(
67+
attn_output,
68+
query,
69+
k_total,
70+
v_total,
71+
block_tables.value(),
72+
total_sequence_lengths.value(),
73+
std::nullopt,
74+
scale_);
75+
}
76+
return attn_output->view({1, seq_len, num_heads_ * head_dim_});
77+
}
78+
79+
std::tuple<infinicore::Tensor, infinicore::Tensor> HybridAttentionImpl::do_kv_cache_update(const infinicore::Tensor key,
80+
const infinicore::Tensor value,
81+
infinicore::Tensor &kv_cache,
82+
const infinicore::Tensor slot_mapping) const {
83+
auto k_cache_layer = kv_cache->narrow({{0, 0, 1}})->squeeze(0);
84+
auto v_cache_layer = kv_cache->narrow({{0, 1, 1}})->squeeze(0);
85+
infinicore::op::paged_caching_(
86+
k_cache_layer,
87+
v_cache_layer,
88+
key,
89+
value,
90+
slot_mapping);
91+
92+
return {k_cache_layer, v_cache_layer};
93+
}
94+
95+
} // namespace infinilm::layers::attention::backends
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#pragma once
2+
3+
#include "../../../global_state/global_state.hpp"
4+
#include "infinicore/tensor.hpp"
5+
#include <tuple>
6+
7+
namespace infinilm::layers::attention {
8+
class AttentionLayer;
9+
}
10+
11+
namespace infinilm::layers::attention::backends {
12+
13+
class HybridAttentionImpl {
14+
public:
15+
HybridAttentionImpl(size_t num_heads,
16+
size_t head_size,
17+
float scale,
18+
size_t num_kv_heads,
19+
size_t layer_idx);
20+
21+
infinicore::Tensor forward(const AttentionLayer &layer,
22+
const infinicore::Tensor &query,
23+
const infinicore::Tensor &key,
24+
const infinicore::Tensor &value,
25+
infinicore::Tensor &kv_cache,
26+
const infinilm::global_state::AttentionMetadata &attn_metadata) const;
27+
28+
std::tuple<infinicore::Tensor, infinicore::Tensor> do_kv_cache_update(const infinicore::Tensor key,
29+
const infinicore::Tensor value,
30+
infinicore::Tensor &kv_cache,
31+
const infinicore::Tensor slot_mapping) const;
32+
33+
private:
34+
size_t num_heads_;
35+
float scale_;
36+
size_t head_dim_;
37+
size_t max_position_embeddings_;
38+
};
39+
40+
} // namespace infinilm::layers::attention::backends

csrc/models/infinilm_model.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ std::vector<infinicore::Tensor> InfinilmModel::default_allocate_kv_cache_tensors
6060
case backends::AttentionBackend::FLASH_ATTN: {
6161
;
6262
}
63+
case backends::AttentionBackend::HYBRID_ATTN:
6364
case backends::AttentionBackend::PAGED_ATTN: {
6465
auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config);
6566
if (nullptr == paged_kv_cache_config) {

csrc/models/minicpm_sala/minicpm_sala_allocate_kv_cache_tensors.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
namespace infinilm::models::minicpm_sala {
99

1010
std::vector<infinicore::Tensor> minicpm_sala_allocate_kv_cache_tensors(const cache::CacheConfig *cache_config,
11-
const std::shared_ptr<infinilm::config::ModelConfig> &text_config,
12-
const backends::AttentionBackend &attention_backend) {
11+
const std::shared_ptr<infinilm::config::ModelConfig> &text_config,
12+
const backends::AttentionBackend &attention_backend) {
1313
if (nullptr == cache_config) {
1414
return {};
1515
}
@@ -58,6 +58,7 @@ std::vector<infinicore::Tensor> minicpm_sala_allocate_kv_cache_tensors(const cac
5858
}
5959
break;
6060
}
61+
case backends::AttentionBackend::HYBRID_ATTN:
6162
case backends::AttentionBackend::PAGED_ATTN: {
6263
auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config);
6364
if (nullptr == paged_kv_cache_config) {

csrc/models/qwen3_next/qwen3_next_allocate_kv_cache_tensors.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ std::vector<infinicore::Tensor> qwen3_next_allocate_kv_cache_tensors(
5757
case backends::AttentionBackend::FLASH_ATTN: {
5858
;
5959
}
60+
case backends::AttentionBackend::HYBRID_ATTN:
6061
case backends::AttentionBackend::PAGED_ATTN: {
6162
auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config);
6263
if (nullptr == paged_kv_cache_config) {

docs/hybrid_attn_iluvatar.md

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Hybrid Attention on Iluvatar
2+
3+
本文档记录 Iluvatar 平台上的 `hybrid-attn` 路径:Prefill 使用 FlashAttention-2 varlen,Decode 使用 InfiniCore 原生 PagedAttention。
4+
5+
## 改了什么
6+
7+
### InfiniCore
8+
9+
- Iluvatar 被作为 CUDA-compatible ATen device 处理,用于复用 ATen/CUDA tensor 和 stream guard。
10+
- Iluvatar 的 FlashAttention-2 调用走全局 `flash_attn_2_cuda` ABI。
11+
- 适配 Iluvatar 当前 `mha_varlen_fwd` 尾部参数。
12+
- `mha_varlen` 和 `mha_kvcache` 的 FlashAttention 路径使用当前 InfiniCore stream。
13+
- Iluvatar + `--flash-attn` 构建时,`libinfinicore_cpp_api.so` 链接 `flash_attn_2_cuda*.so` 并写入 rpath。
14+
- 构建时同步 PyTorch 的 `_GLIBCXX_USE_CXX11_ABI`,避免 Python 扩展加载时 ABI 符号不匹配。
15+
16+
### InfiniLM
17+
18+
- 新增 `hybrid-attn` attention backend。
19+
- 新增独立 `HybridAttentionImpl`,不改变纯 `flash-attn` 语义。
20+
- `hybrid-attn` 的执行路径:
21+
- Prefill:使用 FA2 varlen,输入为本轮 dense `query/key/value`。
22+
- Decode:使用原生 `paged_attention_`,输入为 paged KV cache。
23+
- `hybrid-attn` 使用 paged KV cache 分配。
24+
- Python CLI/API 层会把 `hybrid-attn` 归一化为 paged cache 路径,避免 cache 类型误配。
25+
26+
## 怎么使用
27+
28+
以下命令以 `/data-aisoft/qyq_models/Qwen2.5-3B-Instruct` 为例。
29+
30+
### 1. 环境变量
31+
32+
FA2 的 `.so` 路径可以直接通过 Python 获取:
33+
34+
```bash
35+
export FLASH_ATTN_2_CUDA_SO=$(python3 -c 'import flash_attn_2_cuda; print(flash_attn_2_cuda.__file__)')
36+
export LD_LIBRARY_PATH=/root/.infini/lib:/usr/local/corex/lib64:/usr/local/corex/lib64/python3/dist-packages/torch/lib:/usr/local/corex/lib64/python3/dist-packages:$LD_LIBRARY_PATH
37+
export PYTHONPATH=/home/zx/InfiniLM/python:/home/zx/InfiniCore/python:/usr/local/corex/lib64/python3/dist-packages:$PYTHONPATH
38+
```
39+
40+
### 2. 构建 InfiniCore
41+
42+
```bash
43+
cd /home/zx/InfiniCore
44+
xmake f --iluvatar-gpu=y --aten=y --flash-attn=/usr/local/corex/lib64/python3/dist-packages
45+
xmake build infinicore_cpp_api
46+
xmake build _infinicore
47+
xmake install -o /root/.infini infinicore_cpp_api
48+
xmake install -o /root/.infini _infinicore
49+
```
50+
51+
同步本地 Python 包中的 InfiniCore 扩展:
52+
53+
```bash
54+
cp -f /root/.infini/lib/libinfinicore_cpp_api.so /home/zx/InfiniCore/python/infinicore/lib/libinfinicore_cpp_api.so
55+
cp -f /root/.infini/lib/_infinicore.cpython-310-x86_64-linux-gnu.so /home/zx/InfiniCore/python/infinicore/lib/_infinicore.cpython-310-x86_64-linux-gnu.so
56+
```
57+
58+
可选检查 `libinfinicore_cpp_api.so` 是否已经链接 FA2:
59+
60+
```bash
61+
readelf -d /root/.infini/lib/libinfinicore_cpp_api.so | grep flash_attn_2_cuda
62+
```
63+
64+
预期能看到类似:
65+
66+
```text
67+
NEEDED Shared library: [/usr/local/corex/lib64/python3/dist-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so]
68+
```
69+
70+
### 3. 构建 InfiniLM
71+
72+
```bash
73+
cd /home/zx/InfiniLM
74+
xmake build _infinilm
75+
xmake install _infinilm
76+
```
77+
78+
### 4. 运行 hybrid-attn 推理
79+
80+
```bash
81+
cd /home/zx/InfiniLM
82+
python3 examples/test_infer.py \
83+
--model /data-aisoft/qyq_models/Qwen2.5-3B-Instruct \
84+
--device iluvatar \
85+
--enable-paged-attn \
86+
--attn hybrid-attn \
87+
--batch-size 1 \
88+
--max-new-tokens 4 \
89+
--prompt "你好" \
90+
--temperature 0.0 \
91+
--top-k 1
92+
```
93+
94+
说明:当前 CLI/API 会将 `hybrid-attn` 归一化到 paged cache 路径;命令中保留 `--enable-paged-attn` 是为了显式表达运行条件。
95+
96+
## Qwen2.5 运行结果
97+
98+
验证环境:
99+
100+
- Platform:Iluvatar
101+
- Model:`/data-aisoft/qyq_models/Qwen2.5-3B-Instruct`
102+
- Attention backend:`hybrid-attn`
103+
- Batch size:1
104+
- Max new tokens:4
105+
- Prompt:`你好`
106+
107+
构建验证:
108+
109+
```text
110+
xmake build infinicore_cpp_api # passed
111+
xmake build _infinicore # passed
112+
xmake build _infinilm # passed
113+
```
114+
115+
推理复现结果:
116+
117+
```text
118+
load weights over! 2431.8737983703613 ms
119+
120+
=================== start generate ====================
121+
Generating: 100%|██████████| 1/1 [00:02<00:00, 2.53s/it]
122+
Resquest 0:
123+
===Query===
124+
<|im_start|>system
125+
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
126+
<|im_start|>user
127+
你好<|im_end|>
128+
<|im_start|>assistant
129+
130+
===Response===
131+
""""
132+
133+
total_time: 2582.32 ms
134+
```
135+
136+
## 当前边界
137+
138+
- 当前稳定验证路径是 Iluvatar + Qwen2.5 + FA2 dense prefill + native paged decode。
139+
- Iluvatar 当前 FA2 varlen 不使用 paged KV cache layout 作为 prefill 输入,hybrid prefill 使用本轮 dense `key/value`。
140+
- `flash-attn` 仍表示纯 FA 路径;`hybrid-attn` 是单独 backend。

python/infinilm/base_config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ def __init__(self):
105105
# Multimodal parameters
106106
self.image = self.args.image
107107

108+
if self.attn == "hybrid-attn":
109+
self.enable_paged_attn = True
110+
108111
if self.enable_paged_attn and self.attn == "default":
109112
self.attn = "paged-attn"
110113

@@ -119,7 +122,7 @@ def _add_common_args(self):
119122
"--attn",
120123
type=str,
121124
default="default",
122-
choices=["default", "paged-attn", "flash-attn"],
125+
choices=["default", "paged-attn", "flash-attn", "hybrid-attn"],
123126
)
124127
self.parser.add_argument("--enable-graph", action="store_true")
125128
self.parser.add_argument(

0 commit comments

Comments
 (0)