Skip to content

Commit eacb0c4

Browse files
committed
Add topkrouter op wrapper
Made-with: Cursor
1 parent 77cbb09 commit eacb0c4

File tree

4 files changed

+189
-0
lines changed

4 files changed

+189
-0
lines changed

include/infinicore/ops.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,4 @@
3838
#include "ops/silu.hpp"
3939
#include "ops/silu_and_mul.hpp"
4040
#include "ops/swiglu.hpp"
41+
#include "ops/topkrouter.hpp"
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {

// Top-k router op wrapper. From the shapes below this presumably implements
// MoE expert gating (pick top-k scores per row, scaled) — confirm against the
// backend kernel.
class TopKRouter {
public:
    // Backend dispatch signature. Tensor contracts (per the original notes):
    // values_output: [N, topk] float32
    // indices_output: [N, topk] int32
    // input: [N, width] float16/bfloat16/float32
    // correction_bias: [width] float32
    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, float, size_t);

    // Runs the implementation registered for input's device type.
    // Throws std::runtime_error if no backend is registered for it.
    static void execute(Tensor values_output,
                        Tensor indices_output,
                        Tensor input,
                        Tensor correction_bias,
                        float routed_scaling_factor,
                        size_t topk);

    // Per-device-type registry of backend implementations.
    static common::OpDispatcher<schema> &dispatcher();
};

// Allocating form: creates [N, topk] F32 values and I32 indices on input's
// device, fills them via topkrouter_, and returns {values, indices}.
std::pair<Tensor, Tensor> topkrouter(Tensor input,
                                     Tensor correction_bias,
                                     float routed_scaling_factor,
                                     size_t topk);

// In-place form: writes results into caller-provided output tensors.
void topkrouter_(Tensor values_output,
                 Tensor indices_output,
                 Tensor input,
                 Tensor correction_bias,
                 float routed_scaling_factor,
                 size_t topk);

} // namespace infinicore::op
37+
38+
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {

// NOTE(review): this is a second, byte-identical copy of topkrouter.hpp shown
// by the diff view — it is the same header, not a second file.

// Top-k router op wrapper (see shape contracts on `schema` below).
class TopKRouter {
public:
    // Backend dispatch signature. Tensor contracts (per the original notes):
    // values_output: [N, topk] float32
    // indices_output: [N, topk] int32
    // input: [N, width] float16/bfloat16/float32
    // correction_bias: [width] float32
    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, float, size_t);

    // Runs the implementation registered for input's device type.
    // Throws std::runtime_error if no backend is registered for it.
    static void execute(Tensor values_output,
                        Tensor indices_output,
                        Tensor input,
                        Tensor correction_bias,
                        float routed_scaling_factor,
                        size_t topk);

    // Per-device-type registry of backend implementations.
    static common::OpDispatcher<schema> &dispatcher();
};

// Allocating form: creates [N, topk] F32 values and I32 indices on input's
// device, fills them via topkrouter_, and returns {values, indices}.
std::pair<Tensor, Tensor> topkrouter(Tensor input,
                                     Tensor correction_bias,
                                     float routed_scaling_factor,
                                     size_t topk);

// In-place form: writes results into caller-provided output tensors.
void topkrouter_(Tensor values_output,
                 Tensor indices_output,
                 Tensor input,
                 Tensor correction_bias,
                 float routed_scaling_factor,
                 size_t topk);

} // namespace infinicore::op
74+
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#include "infinicore/ops/topkrouter.hpp"
2+
3+
#include "../../utils.hpp"
4+
5+
#include <stdexcept>
6+
7+
namespace infinicore::op {
8+
9+
// Returns the process-wide registry mapping device types to TopKRouter
// backend implementations (function-local static: thread-safe lazy init).
common::OpDispatcher<TopKRouter::schema> &TopKRouter::dispatcher() {
    static common::OpDispatcher<TopKRouter::schema> dispatcher_;
    return dispatcher_;
} // fixed: dropped stray ';' after the function body (-Wextra-semi)
13+
14+
// Dispatches a TopKRouter call to the backend registered for input's device.
//
// values_output:   [N, topk] float32 (written)
// indices_output:  [N, topk] int32 (written)
// input:           [N, width] float16/bfloat16/float32
// correction_bias: [width] float32
//
// All four tensors must live on the same device (asserted). Throws
// std::runtime_error when no implementation is registered for the device
// type of `input`.
void TopKRouter::execute(Tensor values_output,
                         Tensor indices_output,
                         Tensor input,
                         Tensor correction_bias,
                         float routed_scaling_factor,
                         size_t topk) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(values_output, indices_output, input, correction_bias);

    // Make the execution context target the tensors' device before lookup.
    infinicore::context::setDevice(input->device());

    const auto dev_type = input->device().getType();
    const auto impl = dispatcher().lookup(dev_type);
    if (impl == nullptr) {
        throw std::runtime_error("No TopKRouter implementation found for device type: " + std::to_string(static_cast<int>(dev_type)));
    }

    impl(values_output, indices_output, input, correction_bias, routed_scaling_factor, topk);
}
29+
30+
// Allocating wrapper around TopKRouter::execute.
//
// input:           [N, width] float16/bfloat16/float32
// correction_bias: [width] float32
// Returns {values [N, topk] float32, indices [N, topk] int32}, both newly
// allocated on input's device.
//
// Throws std::runtime_error if input is not 2D, or if topk is 0 or exceeds
// the routing width (either would produce a degenerate/oversized output).
std::pair<Tensor, Tensor> topkrouter(Tensor input,
                                     Tensor correction_bias,
                                     float routed_scaling_factor,
                                     size_t topk) {
    auto shape = input->shape();
    if (shape.size() != 2) {
        throw std::runtime_error("topkrouter: input must be 2D [N, width]");
    }
    // Added: reject out-of-range topk up front instead of handing the
    // backend an invalid selection size.
    if (topk == 0 || topk > shape[1]) {
        throw std::runtime_error("topkrouter: topk must be in [1, width]");
    }
    // values: float32, indices: int32
    Tensor values = Tensor::empty({shape[0], topk}, DataType::F32, input->device());
    Tensor indices = Tensor::empty({shape[0], topk}, DataType::I32, input->device());
    topkrouter_(values, indices, input, correction_bias, routed_scaling_factor, topk);
    return {values, indices};
}
44+
45+
// In-place variant: writes top-k values/indices into the caller-provided
// output tensors. Thin forwarder to TopKRouter::execute (which performs the
// same-device assertion and backend dispatch).
void topkrouter_(Tensor values_output,
                 Tensor indices_output,
                 Tensor input,
                 Tensor correction_bias,
                 float routed_scaling_factor,
                 size_t topk) {
    TopKRouter::execute(values_output, indices_output, input, correction_bias, routed_scaling_factor, topk);
}
53+
54+
} // namespace infinicore::op
55+
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#include "../../utils.hpp"
2+
#include "infinicore/common/hash.hpp"
3+
#include "infinicore/ops/common/cache.hpp"
4+
#include "infinicore/ops/topkrouter.hpp"
5+
6+
#include <infiniop.h>
7+
8+
namespace infinicore::op::topkrouter_impl::infiniop {
9+
10+
// Per-thread cache of infiniop TopKRouter descriptors, capacity 100 entries
// (presumably per device — getCache(device) below selects a sub-cache;
// confirm against common::OpCache). The callback runs on eviction and
// destroys the descriptor so cached handles are not leaked.
thread_local common::OpCache<size_t, infiniopTopkrouterDescriptor_t> caches(
    100,
    [](infiniopTopkrouterDescriptor_t &desc) {
        if (desc != nullptr) {
            INFINICORE_CHECK_ERROR(infiniopDestroyTopkrouterDescriptor(desc));
            desc = nullptr;
        }
    });
18+
19+
// infiniop backend for TopKRouter (matches TopKRouter::schema).
// Creates (or reuses a cached) infiniop descriptor for the input/bias
// layouts, sizes and allocates a workspace, then launches the kernel on the
// current context stream.
void calculate(Tensor values_output,
               Tensor indices_output,
               Tensor input,
               Tensor correction_bias,
               float routed_scaling_factor,
               size_t topk) {
    // Cache key over the tensor descriptors and topk. NOTE(review): the
    // descriptor is created from input/correction_bias only, so including the
    // outputs and topk in the key over-partitions the cache (harmless) —
    // depends on what hash_combine hashes for a Tensor; confirm.
    size_t seed = hash_combine(values_output, indices_output, input, correction_bias, (size_t)topk);

    auto device = context::getDevice();
    auto &cache = caches.getCache(device);

    auto desc_opt = cache.get(seed);
    infiniopTopkrouterDescriptor_t desc = nullptr;

    if (!desc_opt) {
        // Cache miss: build a fresh descriptor from the tensor layouts and
        // remember it for subsequent calls with the same key.
        INFINICORE_CHECK_ERROR(infiniopCreateTopkrouterDescriptor(
            context::getInfiniopHandle(device), &desc,
            input->desc(), correction_bias->desc()));
        cache.put(seed, desc);
    } else {
        desc = *desc_opt;
    }

    // Workspace is queried and allocated per call; the Memory shared_ptr
    // frees it when this function returns (assumes the op's use of the
    // workspace completes or is stream-ordered safely — TODO confirm).
    size_t workspace_size = 0;
    INFINICORE_CHECK_ERROR(infiniopGetTopkrouterWorkspaceSize(desc, &workspace_size));
    std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);

    // Launch: scaling factor and topk are runtime parameters, not baked into
    // the descriptor.
    INFINICORE_CHECK_ERROR(infiniopTopkrouter(
        desc, workspace->data(), workspace_size,
        values_output->data(), indices_output->data(),
        input->data(), correction_bias->data(),
        routed_scaling_factor, topk, context::getStream()));
}
52+
53+
// Self-registration at static-initialization time: installs `calculate` as
// the TopKRouter implementation for all device types via registerAll.
// NOTE(review): the `false` argument's meaning (likely "do not override
// existing registrations") is defined by OpDispatcher — confirm there.
static bool registered = []() {
    TopKRouter::dispatcher().registerAll(&calculate, false);
    return true;
}();
57+
58+
} // namespace infinicore::op::topkrouter_impl::infiniop
59+

0 commit comments

Comments
 (0)