Skip to content

Commit 9783a39

Browse files
author
zhangyue
committed
feat(ascend): add top-k top-p sampler
1 parent 76094ad commit 9783a39

3 files changed

Lines changed: 423 additions & 0 deletions

File tree

src/base/top_k_top_p_sampler.h

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
#ifndef INFINI_OPS_BASE_TOP_K_TOP_P_SAMPLER_H_
2+
#define INFINI_OPS_BASE_TOP_K_TOP_P_SAMPLER_H_
3+
4+
#include <cassert>
5+
#include <optional>
6+
7+
#include "data_type.h"
8+
#include "operator.h"
9+
#include "tensor.h"
10+
11+
namespace infini::ops {
12+
13+
// `TopKTopPSampler` samples token ids from 2D `logits` after optional rank and
14+
// nucleus filtering. The name and tensor boundary follow vLLM's
15+
// `TopKTopPSampler`; temperature scaling is intentionally handled by callers.
16+
// The optional `k` and `p` tensors may be shaped as `[1]` or `[batch_size]`.
17+
class TopKTopPSampler : public Operator<TopKTopPSampler> {
18+
public:
19+
TopKTopPSampler(const Tensor logits, std::optional<Tensor> k,
20+
std::optional<Tensor> p, Tensor out)
21+
: batch_size_{logits.size(0)},
22+
vocab_size_{logits.size(1)},
23+
dtype_{logits.dtype()} {
24+
assert(logits.ndim() == 2 &&
25+
"`TopKTopPSampler` requires 2D `[batch_size, vocab_size]` logits");
26+
assert((dtype_ == DataType::kFloat16 || dtype_ == DataType::kBFloat16 ||
27+
dtype_ == DataType::kFloat32 || dtype_ == DataType::kFloat64) &&
28+
"`TopKTopPSampler` requires floating-point logits");
29+
assert(out.ndim() == 1 &&
30+
"`TopKTopPSampler` requires 1D `[batch_size]` output");
31+
assert(out.size(0) == batch_size_ &&
32+
"`TopKTopPSampler` requires output batch size to match logits");
33+
assert(out.dtype() == DataType::kInt32 &&
34+
"`TopKTopPSampler` requires int32 output");
35+
36+
ValidateK(k);
37+
ValidateP(p);
38+
}
39+
40+
virtual void operator()(const Tensor logits, std::optional<Tensor> k,
41+
std::optional<Tensor> p, Tensor out) const = 0;
42+
43+
protected:
44+
void ValidateK(std::optional<Tensor> k) const {
45+
if (!k.has_value()) return;
46+
47+
assert(k->ndim() == 1 &&
48+
"`TopKTopPSampler` requires `k` to be 1D when provided");
49+
assert((k->size(0) == 1 || k->size(0) == batch_size_) &&
50+
"`TopKTopPSampler` requires `k` shape [1] or [batch_size]");
51+
assert((k->dtype() == DataType::kInt32 || k->dtype() == DataType::kInt64) &&
52+
"`TopKTopPSampler` requires int32 or int64 `k`");
53+
}
54+
55+
void ValidateP(std::optional<Tensor> p) const {
56+
if (!p.has_value()) return;
57+
58+
assert(p->ndim() == 1 &&
59+
"`TopKTopPSampler` requires `p` to be 1D when provided");
60+
assert((p->size(0) == 1 || p->size(0) == batch_size_) &&
61+
"`TopKTopPSampler` requires `p` shape [1] or [batch_size]");
62+
assert((p->dtype() == DataType::kFloat16 ||
63+
p->dtype() == DataType::kBFloat16 ||
64+
p->dtype() == DataType::kFloat32 ||
65+
p->dtype() == DataType::kFloat64) &&
66+
"`TopKTopPSampler` requires floating-point `p`");
67+
}
68+
69+
Tensor::Size batch_size_{0};
70+
71+
Tensor::Size vocab_size_{0};
72+
73+
DataType dtype_;
74+
};
75+
76+
} // namespace infini::ops
77+
78+
#endif // INFINI_OPS_BASE_TOP_K_TOP_P_SAMPLER_H_
Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
#ifndef INFINI_OPS_ASCEND_TOP_K_TOP_P_SAMPLER_KERNEL_H_
2+
#define INFINI_OPS_ASCEND_TOP_K_TOP_P_SAMPLER_KERNEL_H_
3+
4+
#include <algorithm>
5+
#include <cassert>
6+
#include <cstdint>
7+
#include <cstring>
8+
#include <optional>
9+
#include <vector>
10+
11+
#include "acl/acl.h"
12+
#include "aclnn/aclnn_base.h"
13+
#include "aclnnop/aclnn_cast.h"
14+
#include "aclnnop/aclnn_top_k_top_p_sample.h"
15+
#include "base/top_k_top_p_sampler.h"
16+
#include "data_type.h"
17+
#include "native/ascend/common.h"
18+
#include "native/ascend/workspace_pool_.h"
19+
#include "operator.h"
20+
#include "tensor.h"
21+
22+
namespace infini::ops {
23+
24+
template <>
25+
class Operator<TopKTopPSampler, Device::Type::kAscend, 0>
26+
: public TopKTopPSampler {
27+
public:
28+
Operator(const Tensor logits, std::optional<Tensor> k,
29+
std::optional<Tensor> p, Tensor out)
30+
: TopKTopPSampler(logits, k, p, out) {
31+
assert((dtype_ == DataType::kFloat16 || dtype_ == DataType::kBFloat16) &&
32+
"`TopKTopPSampler` Ascend ACLNN path requires float16 or bfloat16 "
33+
"logits");
34+
assert(logits.IsContiguous() &&
35+
"`TopKTopPSampler` Ascend ACLNN path requires contiguous logits");
36+
assert(out.IsContiguous() &&
37+
"`TopKTopPSampler` Ascend ACLNN path requires contiguous output");
38+
ValidateHostTensor(k);
39+
ValidateHostTensor(p);
40+
41+
logits_cache_ = ascend::AclTensorCache(logits);
42+
top_k_cache_ = ascend::AclTensorCache({static_cast<int64_t>(batch_size_)},
43+
ACL_INT32, nullptr);
44+
top_p_cache_ = ascend::AclTensorCache({static_cast<int64_t>(batch_size_)},
45+
ascend::ToAclDtype(dtype_), nullptr);
46+
selected_idx_cache_ = ascend::AclTensorCache(
47+
{static_cast<int64_t>(batch_size_)}, ACL_INT64, nullptr);
48+
selected_logits_cache_ = ascend::AclTensorCache(
49+
{static_cast<int64_t>(batch_size_), static_cast<int64_t>(vocab_size_)},
50+
ACL_FLOAT, nullptr);
51+
out_cache_ = ascend::AclTensorCache(out);
52+
}
53+
54+
~Operator() {
55+
if (!ascend::IsAclRuntimeAlive()) return;
56+
57+
logits_cache_.release();
58+
top_k_cache_.release();
59+
top_p_cache_.release();
60+
selected_idx_cache_.release();
61+
selected_logits_cache_.release();
62+
out_cache_.release();
63+
}
64+
65+
void operator()(const Tensor logits, std::optional<Tensor> k,
66+
std::optional<Tensor> p, Tensor out) const override {
67+
assert(logits.IsContiguous() &&
68+
"`TopKTopPSampler` Ascend ACLNN path requires contiguous logits");
69+
assert(out.IsContiguous() &&
70+
"`TopKTopPSampler` Ascend ACLNN path requires contiguous output");
71+
assert(IsGreedy(k) &&
72+
"`TopKTopPSampler` Ascend ACLNN path supports `top_k == 1` only");
73+
74+
auto stream = static_cast<aclrtStream>(stream_);
75+
auto top_k_bytes = batch_size_ * kDataTypeToSize.at(DataType::kInt32);
76+
auto top_p_bytes = batch_size_ * kDataTypeToSize.at(dtype_);
77+
auto selected_idx_bytes =
78+
batch_size_ * kDataTypeToSize.at(DataType::kInt64);
79+
auto selected_logits_bytes =
80+
batch_size_ * vocab_size_ * kDataTypeToSize.at(DataType::kFloat32);
81+
82+
FillGreedyParams(p);
83+
84+
auto& top_k_arena = ascend::GetWorkspacePool().Ensure(
85+
stream, top_k_bytes, "top_k_top_p_sample_top_k");
86+
auto& top_p_arena = ascend::GetWorkspacePool().Ensure(
87+
stream, top_p_bytes, "top_k_top_p_sample_top_p");
88+
auto ret = aclrtMemcpy(top_k_arena.buf, top_k_bytes, top_k_host_.data(),
89+
top_k_bytes, ACL_MEMCPY_HOST_TO_DEVICE);
90+
assert(ret == ACL_SUCCESS &&
91+
"`TopKTopPSampler`: copying `top_k` to Ascend failed");
92+
ret = aclrtMemcpy(top_p_arena.buf, top_p_bytes, top_p_host_.data(),
93+
top_p_bytes, ACL_MEMCPY_HOST_TO_DEVICE);
94+
assert(ret == ACL_SUCCESS &&
95+
"`TopKTopPSampler`: copying `top_p` to Ascend failed");
96+
97+
auto& selected_idx_arena = ascend::GetWorkspacePool().Ensure(
98+
stream, selected_idx_bytes, "top_k_top_p_sample_idx");
99+
auto& selected_logits_arena = ascend::GetWorkspacePool().Ensure(
100+
stream, selected_logits_bytes, "top_k_top_p_sample_logits");
101+
102+
auto t_logits = logits_cache_.get(const_cast<void*>(logits.data()));
103+
auto t_top_k = top_k_cache_.get(top_k_arena.buf);
104+
auto t_top_p = top_p_cache_.get(top_p_arena.buf);
105+
auto t_selected_idx = selected_idx_cache_.get(selected_idx_arena.buf);
106+
auto t_selected_logits =
107+
selected_logits_cache_.get(selected_logits_arena.buf);
108+
109+
if (!sample_exec_) {
110+
ret = aclnnTopKTopPSampleGetWorkspaceSize(
111+
t_logits, t_top_k, t_top_p,
112+
/*qOptional=*/nullptr, /*eps=*/1e-8, /*isNeedLogits=*/false,
113+
/*topKGuess=*/32, t_selected_idx, t_selected_logits, &sample_ws_size_,
114+
&sample_exec_);
115+
assert(ret == ACL_SUCCESS &&
116+
"`aclnnTopKTopPSampleGetWorkspaceSize` failed");
117+
aclSetAclOpExecutorRepeatable(sample_exec_);
118+
} else {
119+
aclSetInputTensorAddr(sample_exec_, 0, t_logits,
120+
const_cast<void*>(logits.data()));
121+
aclSetInputTensorAddr(sample_exec_, 1, t_top_k, top_k_arena.buf);
122+
aclSetInputTensorAddr(sample_exec_, 2, t_top_p, top_p_arena.buf);
123+
aclSetOutputTensorAddr(sample_exec_, 0, t_selected_idx,
124+
selected_idx_arena.buf);
125+
aclSetOutputTensorAddr(sample_exec_, 1, t_selected_logits,
126+
selected_logits_arena.buf);
127+
}
128+
129+
auto& sample_ws_arena = ascend::GetWorkspacePool().Ensure(
130+
stream, sample_ws_size_, "top_k_top_p_sample_workspace");
131+
ret = aclnnTopKTopPSample(sample_ws_arena.buf, sample_ws_size_,
132+
sample_exec_, stream);
133+
assert(ret == ACL_SUCCESS && "`aclnnTopKTopPSample` failed");
134+
135+
CastSelectedIdx(selected_idx_arena.buf, out);
136+
}
137+
138+
private:
139+
void ValidateHostTensor(std::optional<Tensor> tensor) const {
140+
if (!tensor.has_value()) return;
141+
142+
assert(tensor->device().type() == Device::Type::kCpu &&
143+
"`TopKTopPSampler` Ascend path currently requires host-side "
144+
"`k`/`p` tensors");
145+
assert(tensor->IsContiguous() &&
146+
"`TopKTopPSampler` Ascend path requires contiguous `k`/`p` "
147+
"tensors");
148+
}
149+
150+
bool IsGreedy(std::optional<Tensor> k) const {
151+
if (!k.has_value()) return false;
152+
153+
for (Tensor::Size row = 0; row < batch_size_; ++row) {
154+
if (GetK(k, row) != 1) return false;
155+
}
156+
157+
return true;
158+
}
159+
160+
void CastSelectedIdx(void* selected_idx, Tensor out) const {
161+
auto stream = static_cast<aclrtStream>(stream_);
162+
auto t_selected_idx = selected_idx_cache_.get(selected_idx);
163+
auto t_out = out_cache_.get(out.data());
164+
165+
if (!cast_exec_) {
166+
auto ret = aclnnCastGetWorkspaceSize(t_selected_idx, ACL_INT32, t_out,
167+
&cast_ws_size_, &cast_exec_);
168+
assert(ret == ACL_SUCCESS && "`aclnnCastGetWorkspaceSize` failed");
169+
aclSetAclOpExecutorRepeatable(cast_exec_);
170+
} else {
171+
aclSetInputTensorAddr(cast_exec_, 0, t_selected_idx, selected_idx);
172+
aclSetOutputTensorAddr(cast_exec_, 0, t_out, out.data());
173+
}
174+
175+
auto& cast_ws_arena = ascend::GetWorkspacePool().Ensure(
176+
stream, cast_ws_size_, "top_k_top_p_sample_cast_workspace");
177+
auto ret = aclnnCast(cast_ws_arena.buf, cast_ws_size_, cast_exec_, stream);
178+
assert(ret == ACL_SUCCESS && "`aclnnCast` failed");
179+
}
180+
181+
void FillGreedyParams(std::optional<Tensor> p) const {
182+
top_k_host_.assign(batch_size_, 1);
183+
top_p_host_.resize(batch_size_ * kDataTypeToSize.at(dtype_));
184+
185+
for (Tensor::Size row = 0; row < batch_size_; ++row) {
186+
auto value = static_cast<float>(GetP(p, row));
187+
auto* dst = top_p_host_.data() + row * kDataTypeToSize.at(dtype_);
188+
189+
if (dtype_ == DataType::kFloat16) {
190+
auto converted = Float16::FromFloat(value);
191+
std::memcpy(dst, &converted, sizeof(converted));
192+
} else {
193+
auto converted = BFloat16::FromFloat(value);
194+
std::memcpy(dst, &converted, sizeof(converted));
195+
}
196+
}
197+
}
198+
199+
int64_t GetK(std::optional<Tensor> k, Tensor::Size row) const {
200+
if (!k.has_value()) return static_cast<int64_t>(vocab_size_);
201+
202+
const auto offset = k->size(0) == 1 ? 0 : row;
203+
int64_t value = 0;
204+
if (k->dtype() == DataType::kInt32) {
205+
value = static_cast<const int32_t*>(k->data())[offset];
206+
} else {
207+
value = static_cast<const int64_t*>(k->data())[offset];
208+
}
209+
210+
if (value <= 0) return static_cast<int64_t>(vocab_size_);
211+
return std::min<int64_t>(value, static_cast<int64_t>(vocab_size_));
212+
}
213+
214+
double GetP(std::optional<Tensor> p, Tensor::Size row) const {
215+
if (!p.has_value()) return 1.0;
216+
217+
const auto offset = p->size(0) == 1 ? 0 : row;
218+
double value = 1.0;
219+
switch (p->dtype()) {
220+
case DataType::kFloat16:
221+
value = static_cast<const Float16*>(p->data())[offset].ToFloat();
222+
break;
223+
case DataType::kBFloat16:
224+
value = static_cast<const BFloat16*>(p->data())[offset].ToFloat();
225+
break;
226+
case DataType::kFloat32:
227+
value = static_cast<const float*>(p->data())[offset];
228+
break;
229+
case DataType::kFloat64:
230+
value = static_cast<const double*>(p->data())[offset];
231+
break;
232+
default:
233+
assert(false && "`TopKTopPSampler` has unsupported `p` dtype");
234+
}
235+
236+
if (value <= 0.0 || value > 1.0) return 1.0;
237+
return value;
238+
}
239+
240+
mutable ascend::AclTensorCache logits_cache_;
241+
242+
mutable ascend::AclTensorCache top_k_cache_;
243+
244+
mutable ascend::AclTensorCache top_p_cache_;
245+
246+
mutable ascend::AclTensorCache selected_idx_cache_;
247+
248+
mutable ascend::AclTensorCache selected_logits_cache_;
249+
250+
mutable ascend::AclTensorCache out_cache_;
251+
252+
mutable std::vector<int32_t> top_k_host_;
253+
254+
mutable std::vector<std::uint8_t> top_p_host_;
255+
256+
mutable aclOpExecutor* sample_exec_ = nullptr;
257+
258+
mutable uint64_t sample_ws_size_ = 0;
259+
260+
mutable aclOpExecutor* cast_exec_ = nullptr;
261+
262+
mutable uint64_t cast_ws_size_ = 0;
263+
};
264+
265+
} // namespace infini::ops
266+
267+
#endif // INFINI_OPS_ASCEND_TOP_K_TOP_P_SAMPLER_KERNEL_H_

0 commit comments

Comments
 (0)