Skip to content

Commit dce9986

Browse files
authored
Merge pull request #1053 from InfiniTensor/issue/1033xmake
Issue/1033 patch aten and fa adaptations
2 parents 8d99a8f + d6e44e8 commit dce9986

102 files changed

Lines changed: 1432 additions & 451 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

README.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ python scripts/install.py [XMAKE_CONFIG_FLAGS]
107107
| `--ali-ppu=[y\|n]` | 是否编译阿里 PPU 接口实现 | n
108108
| `--ninetoothed=[y\|n]` | 是否编译九齿实现 | n
109109
| `--ccl=[y\|n]` | 是否编译 InfiniCCL 通信库接口实现 | n
110+
| `--graph=[y\|n]` | 是否编译 cuda graph 接口实现 | n
110111

111112
##### 手动安装底层库
112113

@@ -154,6 +155,20 @@ python scripts/install.py [XMAKE_CONFIG_FLAGS]
154155
xmake f --ascend-npu=true -cv
155156
```
156157

158+
##### 试验功能 -- 使用flash attention库中的算子
159+
160+
```shell
161+
162+
(1) 在third_party目录拉取cutlass和flash attn库的源码(不需要--recursive)
163+
164+
(2) 设置(1)中cutlass路径的环境变量CUTLASS_ROOT
165+
166+
(3) xmake配置环节额外打开 --aten 开关,并设置 --flash-attn 库位置,例:
167+
xmake f --nv-gpu=y --ccl=y --cuda=$CUDA_HOME --aten=y --flash-attn=<path-to>/InfiniCore/third_party/flash-attention -cv
168+
169+
(4) flash attention库会伴随infinicore_cpp_api一同编译安装
170+
```
171+
157172
2. 编译安装
158173

159174
默认安装路径为 `$HOME/.infini`

include/infiniccl.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@ struct InfinicclComm;
1515

1616
typedef struct InfinicclComm *infinicclComm_t;
1717

18-
__C __export infiniStatus_t infinicclCommInitAll(
18+
__INFINI_C __export infiniStatus_t infinicclCommInitAll(
1919
infiniDevice_t device_type,
2020
infinicclComm_t *comms,
2121
int ndevice,
2222
const int *device_ids);
2323

24-
__C __export infiniStatus_t infinicclCommDestroy(infinicclComm_t comm);
24+
__INFINI_C __export infiniStatus_t infinicclCommDestroy(infinicclComm_t comm);
2525

26-
__C __export infiniStatus_t infinicclAllReduce(
26+
__INFINI_C __export infiniStatus_t infinicclAllReduce(
2727
void *sendbuf,
2828
void *recvbuf,
2929
size_t count,

include/infinicore.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
#endif
1111

1212
#ifdef __cplusplus
13-
#define __C extern "C"
13+
#define __INFINI_C extern "C"
1414
#include <cstddef>
1515
#else
16-
#define __C
16+
#define __INFINI_C
1717
#include <stddef.h>
1818
#endif
1919

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#ifdef ENABLE_ATEN
2+
#pragma once
3+
#include "../context/context.hpp"
4+
#include "../tensor.hpp"
5+
6+
#include <ATen/ATen.h>
7+
8+
#ifdef ENABLE_NVIDIA_API
9+
#include <ATen/cuda/CUDAContext.h>
10+
#include <c10/cuda/CUDAGuard.h>
11+
#endif
12+
13+
namespace infinicore::adaptor {
14+
/// Translate an InfiniCore DataType into the matching ATen scalar type.
///
/// Only F32, F16, BF16, I32 and I64 are mapped here; any other value
/// raises std::runtime_error.
inline at::ScalarType to_at_dtype(DataType dtype) {
    switch (dtype) {
    case DataType::F32:  return at::kFloat;
    case DataType::F16:  return at::kHalf;
    case DataType::BF16: return at::kBFloat16;
    case DataType::I32:  return at::kInt;
    case DataType::I64:  return at::kLong;
    default:             break;
    }
    // Reached only for dtypes that have no mapping in the table above.
    throw std::runtime_error("Unsupported dtype for ATen");
}
30+
31+
/// Translate an InfiniCore Device into an ATen device.
///
/// NVIDIA devices become CUDA devices carrying the same index, CPU devices
/// become the ATen CPU device, and every other device type raises
/// std::runtime_error.
inline at::Device to_at_device(const Device &device) {
    const auto kind = device.getType();
    if (kind == Device::Type::NVIDIA) {
        return at::Device(at::kCUDA, device.getIndex());
    }
    if (kind == Device::Type::CPU) {
        return at::Device(at::kCPU);
    }
    // Neither of the two supported backends matched.
    throw std::runtime_error("Unsupported device type for ATen");
}
40+
41+
at::Tensor to_aten_tensor(const infinicore::Tensor &t);
42+
43+
#ifdef ENABLE_NVIDIA_API
44+
c10::cuda::CUDAStream get_cuda_stream();
45+
#endif
46+
} // namespace infinicore::adaptor
47+
48+
#endif // ENABLE_ATEN
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
#ifdef ENABLE_FLASH_ATTN
2+
#pragma once
3+
#include "aten_adaptor.hpp"
4+
5+
namespace flash {
6+
std::vector<at::Tensor>
7+
mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
8+
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
9+
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
10+
std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
11+
std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
12+
const float p_dropout,
13+
const float softmax_scale,
14+
bool is_causal,
15+
int window_size_left,
16+
int window_size_right,
17+
const float softcap,
18+
const bool return_softmax,
19+
std::optional<at::Generator> gen_);
20+
21+
std::vector<at::Tensor>
22+
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
23+
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
24+
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
25+
std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
26+
const at::Tensor &cu_seqlens_q, // b+1
27+
const at::Tensor &cu_seqlens_k, // b+1
28+
std::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
29+
std::optional<const at::Tensor> &leftpad_k_, // batch_size
30+
std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
31+
std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
32+
int max_seqlen_q,
33+
const int max_seqlen_k,
34+
const float p_dropout,
35+
const float softmax_scale,
36+
const bool zero_tensors,
37+
bool is_causal,
38+
int window_size_left,
39+
int window_size_right,
40+
const float softcap,
41+
const bool return_softmax,
42+
std::optional<at::Generator> gen_);
43+
44+
std::vector<at::Tensor>
45+
mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
46+
const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
47+
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
48+
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
49+
const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
50+
const at::Tensor &softmax_lse, // b x h x seqlen_q
51+
std::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
52+
std::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
53+
std::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
54+
std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
55+
const float p_dropout, // probability to drop
56+
const float softmax_scale,
57+
const bool is_causal,
58+
int window_size_left,
59+
int window_size_right,
60+
const float softcap,
61+
const bool deterministic,
62+
std::optional<at::Generator> gen_,
63+
std::optional<at::Tensor> &rng_state);
64+
65+
std::vector<at::Tensor>
66+
mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
67+
const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
68+
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
69+
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
70+
const at::Tensor &out, // total_q x num_heads x head_size
71+
const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp
72+
std::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
73+
std::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
74+
std::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
75+
const at::Tensor &cu_seqlens_q, // b+1
76+
const at::Tensor &cu_seqlens_k, // b+1
77+
std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
78+
const int max_seqlen_q,
79+
const int max_seqlen_k, // max sequence length to choose the kernel
80+
const float p_dropout, // probability to drop
81+
const float softmax_scale,
82+
const bool zero_tensors,
83+
const bool is_causal,
84+
int window_size_left,
85+
int window_size_right,
86+
const float softcap,
87+
const bool deterministic,
88+
std::optional<at::Generator> gen_,
89+
std::optional<at::Tensor> &rng_state);
90+
91+
std::vector<at::Tensor>
92+
mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
93+
const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
94+
const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
95+
std::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
96+
std::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
97+
std::optional<const at::Tensor> &seqlens_k_, // batch_size
98+
std::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
99+
std::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
100+
std::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
101+
std::optional<const at::Tensor> &leftpad_k_, // batch_size
102+
std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
103+
std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
104+
std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
105+
const float softmax_scale,
106+
bool is_causal,
107+
int window_size_left,
108+
int window_size_right,
109+
const float softcap,
110+
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
111+
int num_splits);
112+
113+
} // namespace flash
114+
#endif // ENABLE_FLASH_ATTN
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#pragma once
2+
3+
#include "../device.hpp"
4+
#include "common/op.hpp"
5+
#include <optional>
6+
7+
namespace infinicore::op {
8+
9+
INFINICORE_GRAPH_OP_CLASS(
10+
MultiheadAttentionVarlen,
11+
Tensor,
12+
const Tensor &,
13+
const Tensor &,
14+
const Tensor &,
15+
const Tensor &,
16+
const Tensor &,
17+
const Tensor &,
18+
int,
19+
int,
20+
std::optional<Tensor>,
21+
float);
22+
23+
Tensor mha_varlen(const Tensor &q,
24+
const Tensor &k,
25+
const Tensor &v,
26+
const Tensor &cum_seqlens_q,
27+
const Tensor &cum_seqlens_k,
28+
const Tensor &block_table,
29+
int max_seqlen_q,
30+
int max_seqlen_k,
31+
std::optional<Tensor> alibi_slopes,
32+
float scale);
33+
34+
void mha_varlen_(Tensor out,
35+
const Tensor &q,
36+
const Tensor &k,
37+
const Tensor &v,
38+
const Tensor &cum_seqlens_q,
39+
const Tensor &cum_seqlens_k,
40+
const Tensor &block_table,
41+
int max_seqlen_q,
42+
int max_seqlen_k,
43+
std::optional<Tensor> alibi_slopes,
44+
float scale);
45+
46+
} // namespace infinicore::op

include/infiniop/handle.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ struct InfiniopHandle;
77

88
typedef struct InfiniopHandle *infiniopHandle_t;
99

10-
__C __export infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr);
10+
__INFINI_C __export infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr);
1111

12-
__C __export infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle);
12+
__INFINI_C __export infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle);
1313

1414
#endif

include/infiniop/operator_descriptor.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
// Base descriptor for all operators
88
struct InfiniopDescriptor;
99

10-
__C __export infiniStatus_t infiniopGetDescriptorDeviceType(const struct InfiniopDescriptor *desc_ptr, infiniDevice_t *device_type);
11-
__C __export infiniStatus_t infiniopGetDescriptorDeviceId(const struct InfiniopDescriptor *desc_ptr, int *device_id);
10+
__INFINI_C __export infiniStatus_t infiniopGetDescriptorDeviceType(const struct InfiniopDescriptor *desc_ptr, infiniDevice_t *device_type);
11+
__INFINI_C __export infiniStatus_t infiniopGetDescriptorDeviceId(const struct InfiniopDescriptor *desc_ptr, int *device_id);
1212

1313
#endif //__INFINIOP_OPERATOR_DESCRIPTOR_API_H__

include/infiniop/ops/add.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,22 @@
55

66
typedef struct InfiniopDescriptor *infiniopAddDescriptor_t;
77

8-
__C __export infiniStatus_t infiniopCreateAddDescriptor(infiniopHandle_t handle,
8+
__INFINI_C __export infiniStatus_t infiniopCreateAddDescriptor(infiniopHandle_t handle,
99
infiniopAddDescriptor_t *desc_ptr,
1010
infiniopTensorDescriptor_t c,
1111
infiniopTensorDescriptor_t a,
1212
infiniopTensorDescriptor_t b);
1313

14-
__C __export infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, size_t *size);
14+
__INFINI_C __export infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, size_t *size);
1515

16-
__C __export infiniStatus_t infiniopAdd(infiniopAddDescriptor_t desc,
16+
__INFINI_C __export infiniStatus_t infiniopAdd(infiniopAddDescriptor_t desc,
1717
void *workspace,
1818
size_t workspace_size,
1919
void *c,
2020
const void *a,
2121
const void *b,
2222
void *stream);
2323

24-
__C __export infiniStatus_t infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc);
24+
__INFINI_C __export infiniStatus_t infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc);
2525

2626
#endif

include/infiniop/ops/add_rms_norm.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
typedef struct InfiniopDescriptor *infiniopAddRMSNormDescriptor_t;
77

8-
__C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor(
8+
__INFINI_C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor(
99
infiniopHandle_t handle,
1010
infiniopAddRMSNormDescriptor_t *desc_ptr,
1111
infiniopTensorDescriptor_t y_desc,
@@ -15,9 +15,9 @@ __C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor(
1515
infiniopTensorDescriptor_t weight_desc,
1616
float epsilon);
1717

18-
__C __export infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size);
18+
__INFINI_C __export infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size);
1919

20-
__C __export infiniStatus_t infiniopAddRMSNorm(infiniopAddRMSNormDescriptor_t desc,
20+
__INFINI_C __export infiniStatus_t infiniopAddRMSNorm(infiniopAddRMSNormDescriptor_t desc,
2121
void *workspace,
2222
size_t workspace_size,
2323
void *y,
@@ -27,6 +27,6 @@ __C __export infiniStatus_t infiniopAddRMSNorm(infiniopAddRMSNormDescriptor_t de
2727
const void *weight,
2828
void *stream);
2929

30-
__C __export infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc);
30+
__INFINI_C __export infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc);
3131

3232
#endif

0 commit comments

Comments
 (0)