Skip to content

Commit 63233f9

Browse files
committed
Merge branch 'main' of github.com:InfiniTensor/InfiniCore into w4a16
2 parents f91d297 + 90cb1b5 commit 63233f9

19 files changed

Lines changed: 267 additions & 91 deletions

File tree

src/infiniccl/moore/infiniccl_moore.cc

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ inline mcclDataType_t getMcclDtype(infiniDtype_t datatype) {
2323
return mcclFloat;
2424
case INFINI_DTYPE_F16:
2525
return mcclHalf;
26+
27+
#if MARCH_TYPE == 310
28+
case INFINI_DTYPE_BF16:
29+
return mcclBfloat16;
30+
#endif
31+
2632
default:
2733
std::abort();
2834
return mcclHalf;
@@ -83,9 +89,16 @@ infiniStatus_t allReduce(
8389
infinicclComm_t comm,
8490
infinirtStream_t stream) {
8591

86-
if (datatype != INFINI_DTYPE_F32 && datatype != INFINI_DTYPE_F16) {
87-
return INFINI_STATUS_BAD_PARAM;
88-
}
92+
#if MARCH_TYPE == 310
93+
CHECK_DTYPE(datatype,
94+
INFINI_DTYPE_F32,
95+
INFINI_DTYPE_F16,
96+
INFINI_DTYPE_BF16);
97+
#else
98+
CHECK_DTYPE(datatype,
99+
INFINI_DTYPE_F32,
100+
INFINI_DTYPE_F16);
101+
#endif
89102

90103
CHECK_MCCL(mcclAllReduce(sendbuf, recvbuf, count, getMcclDtype(datatype),
91104
getMcclRedOp(op), getMcclComm(comm), getMusaStream(stream)));

src/infinicore/nn/linear.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ Linear::Linear(size_t in_features, size_t out_features,
131131
this->register_parameter("qweight", weight_);
132132
weight_zeros_ = infinicore::nn::Parameter({out_features, in_features}, infinicore::DataType::I32, device);
133133
this->register_parameter("qzeros", weight_zeros_);
134-
weight_scale_ = infinicore::nn::Parameter({out_features, in_features}, infinicore::DataType::F16, device);
134+
weight_scale_ = infinicore::nn::Parameter({out_features, in_features}, dtype_, device);
135135
this->register_parameter("scales", weight_scale_);
136136
if (bias) {
137137
INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device));

src/infiniop/ops/causal_softmax/operator.cc

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
7272
#ifdef ENABLE_MOORE_API
7373
CREATE(INFINI_DEVICE_MOORE, moore)
7474
#endif
75+
default:
76+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
7577
}
76-
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
78+
#undef CREATE
7779
}
7880

7981
__C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDescriptor_t desc, size_t *size) {
@@ -117,8 +119,10 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
117119
#ifdef ENABLE_MOORE_API
118120
GET(INFINI_DEVICE_MOORE, moore)
119121
#endif
122+
default:
123+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
120124
}
121-
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
125+
#undef GET
122126
}
123127

124128
__C infiniStatus_t infiniopCausalSoftmax(
@@ -167,8 +171,10 @@ __C infiniStatus_t infiniopCausalSoftmax(
167171
#ifdef ENABLE_MOORE_API
168172
CALCULATE(INFINI_DEVICE_MOORE, moore)
169173
#endif
174+
default:
175+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
170176
}
171-
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
177+
#undef CALCULATE
172178
}
173179

174180
__C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxDescriptor_t desc) {
@@ -212,6 +218,8 @@ __C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxD
212218
#ifdef ENABLE_MOORE_API
213219
DESTROY(INFINI_DEVICE_MOORE, moore)
214220
#endif
221+
default:
222+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
215223
}
216-
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
224+
#undef DESTROY
217225
}

src/infiniop/ops/clip/operator.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,11 @@ __C infiniStatus_t infiniopGetClipWorkspaceSize(infiniopClipDescriptor_t desc, s
9191
#ifdef ENABLE_KUNLUN_API
9292
GET(INFINI_DEVICE_KUNLUN, kunlun)
9393
#endif
94+
default:
95+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
9496
}
9597

9698
#undef GET
97-
98-
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
9999
}
100100

101101
__C infiniStatus_t infiniopClip(

src/infiniop/ops/logsoftmax/operator.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,9 @@ __C infiniStatus_t infiniopCreateLogSoftmaxDescriptor(
5151
#ifdef ENABLE_ASCEND_API
5252
// CREATE(INFINI_DEVICE_ASCEND, ascend)
5353
#endif
54+
default:
55+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
5456
}
55-
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
5657
}
5758

5859
__C infiniStatus_t infiniopGetLogSoftmaxWorkspaceSize(infiniopLogSoftmaxDescriptor_t desc, size_t *size) {
@@ -84,8 +85,9 @@ __C infiniStatus_t infiniopGetLogSoftmaxWorkspaceSize(infiniopLogSoftmaxDescript
8485
#ifdef ENABLE_ASCEND_API
8586
// GET(INFINI_DEVICE_ASCEND, ascend)
8687
#endif
88+
default:
89+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
8790
}
88-
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
8991
}
9092

9193
__C infiniStatus_t infiniopLogSoftmax(
@@ -122,8 +124,9 @@ __C infiniStatus_t infiniopLogSoftmax(
122124
#ifdef ENABLE_ASCEND_API
123125
// CALCULATE(INFINI_DEVICE_ASCEND, ascend)
124126
#endif
127+
default:
128+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
125129
}
126-
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
127130
}
128131

129132
__C infiniStatus_t infiniopDestroyLogSoftmaxDescriptor(infiniopLogSoftmaxDescriptor_t desc) {
@@ -155,6 +158,7 @@ __C infiniStatus_t infiniopDestroyLogSoftmaxDescriptor(infiniopLogSoftmaxDescrip
155158
#ifdef ENABLE_ASCEND_API
156159
// DESTROY(INFINI_DEVICE_ASCEND, ascend)
157160
#endif
161+
default:
162+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
158163
}
159-
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
160164
}

src/infiniop/ops/paged_caching/cuda/kernel.cuh

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@ __device__ void pagedCachingKernel(
3838
const ptrdiff_t k_src_stride, // Stride between tokens in the source K tensor
3939
const ptrdiff_t v_src_stride, // Stride between tokens in the source V tensor
4040
const ptrdiff_t k_cache_block_stride, // Stride between blocks in the K cache pool
41-
const ptrdiff_t v_cache_block_stride // Stride between blocks in the V cache pool
41+
const ptrdiff_t v_cache_block_stride, // Stride between blocks in the V cache pool
42+
const ptrdiff_t k_cache_head_stride, // Stride between heads in the K cache pool
43+
const ptrdiff_t v_cache_head_stride, // Stride between heads in the V cache pool
44+
const ptrdiff_t k_cache_slot_stride, // Stride between block slots in the K cache pool
45+
const ptrdiff_t v_cache_slot_stride // Stride between block slots in the V cache pool
4246
) {
4347
//================================================================================
4448
// 1. Identify Work Unit & Calculate Addresses
@@ -66,13 +70,11 @@ __device__ void pagedCachingKernel(
6670

6771
// Destination pointer calculation assumes a [num_blocks, block_size, num_heads, head_size] layout.
6872
// We point to the beginning of the memory region for this token's slot.
69-
const ptrdiff_t cache_head_stride = block_size * head_size;
70-
7173
Tdata *k_cache_block_base_ptr = k_cache_ptr + physical_block_idx * k_cache_block_stride;
72-
Tdata *k_dst_head_ptr = k_cache_block_base_ptr + head_idx * cache_head_stride + block_offset * head_size;
74+
Tdata *k_dst_head_ptr = k_cache_block_base_ptr + head_idx * k_cache_head_stride + block_offset * k_cache_slot_stride;
7375

7476
Tdata *v_cache_block_base_ptr = v_cache_ptr + physical_block_idx * v_cache_block_stride;
75-
Tdata *v_dst_head_ptr = v_cache_block_base_ptr + head_idx * cache_head_stride + block_offset * head_size;
77+
Tdata *v_dst_head_ptr = v_cache_block_base_ptr + head_idx * v_cache_head_stride + block_offset * v_cache_slot_stride;
7678

7779
//================================================================================
7880
// 2. Perform Element-wise Data Copy (Safe, Non-Vectorized)

src/infiniop/ops/paged_caching/info.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ class PagedCachingInfo {
2626
ptrdiff_t v_src_stride;
2727
ptrdiff_t k_cache_block_stride;
2828
ptrdiff_t v_cache_block_stride;
29+
ptrdiff_t k_cache_head_stride;
30+
ptrdiff_t v_cache_head_stride;
31+
ptrdiff_t k_cache_slot_stride;
32+
ptrdiff_t v_cache_slot_stride;
2933

3034
static utils::Result<PagedCachingInfo> create(
3135
infiniopTensorDescriptor_t k_cache_desc,
@@ -63,6 +67,10 @@ class PagedCachingInfo {
6367
ptrdiff_t v_src_stride = v_desc->stride(0);
6468
ptrdiff_t k_cache_block_stride = k_cache_desc->stride(0);
6569
ptrdiff_t v_cache_block_stride = v_cache_desc->stride(0);
70+
ptrdiff_t k_cache_head_stride = k_cache_desc->stride(1);
71+
ptrdiff_t v_cache_head_stride = v_cache_desc->stride(1);
72+
ptrdiff_t k_cache_slot_stride = k_cache_desc->stride(2);
73+
ptrdiff_t v_cache_slot_stride = v_cache_desc->stride(2);
6674

6775
return utils::Result<PagedCachingInfo>(PagedCachingInfo{
6876
dtype,
@@ -73,7 +81,11 @@ class PagedCachingInfo {
7381
k_src_stride,
7482
v_src_stride,
7583
k_cache_block_stride,
76-
v_cache_block_stride});
84+
v_cache_block_stride,
85+
k_cache_head_stride,
86+
v_cache_head_stride,
87+
k_cache_slot_stride,
88+
v_cache_slot_stride});
7789
}
7890
};
7991

src/infiniop/ops/paged_caching/metax/paged_caching_metax.maca

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@ INFINIOP_METAX_KERNEL pagedCaching(
1010
const int64_t *slot_mapping,
1111
const size_t head_size, const size_t block_size,
1212
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
13-
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
13+
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride,
14+
const ptrdiff_t k_cache_head_stride, const ptrdiff_t v_cache_head_stride,
15+
const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_strid) {
1416
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
1517
k_cache, v_cache, k, v, slot_mapping, head_size,
16-
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
18+
block_size, k_src_stride, v_src_stride,
19+
k_cache_block_stride, v_cache_block_stride, k_cache_head_stride, v_cache_head_stride, k_cache_slot_stride, v_cache_slot_stride);
1720
}
1821

1922
namespace op::paged_caching::metax {
@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
5962
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
6063
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
6164
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
65+
ptrdiff_t k_cache_head_stride, ptrdiff_t v_cache_head_stride,
66+
ptrdiff_t k_cache_slot_stride, ptrdiff_t v_cache_slot_stride,
6267
hcStream_t stream) {
6368

6469
// Grid dimension is 1D, with one block per token, as we decided.
@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
8388
k_src_stride,
8489
v_src_stride,
8590
k_cache_block_stride,
86-
v_cache_block_stride);
91+
v_cache_block_stride,
92+
k_cache_head_stride,
93+
v_cache_head_stride,
94+
k_cache_slot_stride,
95+
v_cache_slot_stride);
8796
} else if (dtype == INFINI_DTYPE_BF16) {
8897
pagedCaching<cuda_bfloat16, NUM_THREADS>
8998
<<<grid, block, shared_mem_size, stream>>>(
@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
97106
k_src_stride,
98107
v_src_stride,
99108
k_cache_block_stride,
100-
v_cache_block_stride);
109+
v_cache_block_stride,
110+
k_cache_head_stride,
111+
v_cache_head_stride,
112+
k_cache_slot_stride,
113+
v_cache_slot_stride);
101114
} else if (dtype == INFINI_DTYPE_F32) {
102115
pagedCaching<float, NUM_THREADS>
103116
<<<grid, block, shared_mem_size, stream>>>(
@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
111124
k_src_stride,
112125
v_src_stride,
113126
k_cache_block_stride,
114-
v_cache_block_stride);
127+
v_cache_block_stride,
128+
k_cache_head_stride,
129+
v_cache_head_stride,
130+
k_cache_slot_stride,
131+
v_cache_slot_stride);
115132
} else {
116133
return INFINI_STATUS_BAD_TENSOR_DTYPE;
117134
}
@@ -138,13 +155,17 @@ infiniStatus_t Descriptor::calculate(
138155
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
139156
_info.k_src_stride, _info.v_src_stride,
140157
_info.k_cache_block_stride, _info.v_cache_block_stride,
158+
_info.k_cache_head_stride, _info.v_cache_head_stride,
159+
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
141160
stream);
142161
} else if (max_threads >= METAX_BLOCK_SIZE_512) {
143162
launchKernel<METAX_BLOCK_SIZE_512>(
144163
_info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
145164
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
146165
_info.k_src_stride, _info.v_src_stride,
147166
_info.k_cache_block_stride, _info.v_cache_block_stride,
167+
_info.k_cache_head_stride, _info.v_cache_head_stride,
168+
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
148169
stream);
149170
} else {
150171
// If the device supports fewer threads, return an error.

src/infiniop/ops/paged_caching/moore/paged_caching_moore.mu

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@ INFINIOP_MOORE_KERNEL pagedCaching(
1010
const int64_t *slot_mapping,
1111
const size_t head_size, const size_t block_size,
1212
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
13-
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
13+
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride,
14+
const ptrdiff_t k_cache_head_stride, const ptrdiff_t v_cache_head_stride,
15+
const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_stride) {
1416
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
1517
k_cache, v_cache, k, v, slot_mapping, head_size,
16-
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
18+
block_size, k_src_stride, v_src_stride,
19+
k_cache_block_stride, v_cache_block_stride, k_cache_head_stride, v_cache_head_stride, k_cache_slot_stride, v_cache_slot_stride);
1720
}
1821

1922
namespace op::paged_caching::moore {
@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
5962
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
6063
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
6164
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
65+
ptrdiff_t k_cache_head_stride, ptrdiff_t v_cache_head_stride,
66+
ptrdiff_t k_cache_slot_stride, ptrdiff_t v_cache_slot_stride,
6267
musaStream_t stream) {
6368

6469
// Grid dimension is 1D, with one block per token, as we decided.
@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
8388
k_src_stride,
8489
v_src_stride,
8590
k_cache_block_stride,
86-
v_cache_block_stride);
91+
v_cache_block_stride,
92+
k_cache_head_stride,
93+
v_cache_head_stride,
94+
k_cache_slot_stride,
95+
v_cache_slot_stride);
8796
} else if (dtype == INFINI_DTYPE_BF16) {
8897
pagedCaching<__mt_bfloat16, NUM_THREADS>
8998
<<<grid, block, shared_mem_size, stream>>>(
@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
97106
k_src_stride,
98107
v_src_stride,
99108
k_cache_block_stride,
100-
v_cache_block_stride);
109+
v_cache_block_stride,
110+
k_cache_head_stride,
111+
v_cache_head_stride,
112+
k_cache_slot_stride,
113+
v_cache_slot_stride);
101114
} else if (dtype == INFINI_DTYPE_F32) {
102115
pagedCaching<float, NUM_THREADS>
103116
<<<grid, block, shared_mem_size, stream>>>(
@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
111124
k_src_stride,
112125
v_src_stride,
113126
k_cache_block_stride,
114-
v_cache_block_stride);
127+
v_cache_block_stride,
128+
k_cache_head_stride,
129+
v_cache_head_stride,
130+
k_cache_slot_stride,
131+
v_cache_slot_stride);
115132
} else {
116133
return INFINI_STATUS_BAD_TENSOR_DTYPE;
117134
}
@@ -137,13 +154,17 @@ infiniStatus_t Descriptor::calculate(
137154
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
138155
_info.k_src_stride, _info.v_src_stride,
139156
_info.k_cache_block_stride, _info.v_cache_block_stride,
157+
_info.k_cache_head_stride, _info.v_cache_head_stride,
158+
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
140159
stream);
141160
} else if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_512) {
142161
launchKernel<MOORE_BLOCK_SIZE_512>(
143162
_info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
144163
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
145164
_info.k_src_stride, _info.v_src_stride,
146165
_info.k_cache_block_stride, _info.v_cache_block_stride,
166+
_info.k_cache_head_stride, _info.v_cache_head_stride,
167+
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
147168
stream);
148169
} else {
149170
// If the GPU is older and supports fewer threads, return an error.

0 commit comments

Comments
 (0)