Skip to content

Commit fa7273d

Browse files
committed
add top_k router and casting
Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent eacb0c4 commit fa7273d

File tree

21 files changed: +867 −51 lines changed

include/infinicore/ops.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "ops/binary_cross_entropy_with_logits.hpp"
1515
#include "ops/causal_softmax.hpp"
1616
#include "ops/cdist.hpp"
17+
#include "ops/convert_to_f32.hpp"
1718
#include "ops/cross_entropy.hpp"
1819
#include "ops/embedding.hpp"
1920
#include "ops/flash_attention.hpp"
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {

// Device-dispatched cast-to-float32 operator.
//
// Backends register one kernel per device type in dispatcher();
// execute() looks the kernel up by the output tensor's device type and
// invokes it. The output tensor's dtype must already be F32 (enforced
// by the frontend implementation).
class ConvertToF32 {
public:
    // Kernel signature: (output, input).
    using schema = void (*)(Tensor, Tensor);
    // Runs the kernel registered for output's device; throws if output
    // is not F32 or no kernel is registered for the device type.
    static void execute(Tensor output, Tensor input);
    // Registry mapping device type -> kernel.
    static common::OpDispatcher<schema> &dispatcher();
};

// Allocating variant: returns a new F32 tensor with input's shape,
// holding input cast to float32.
Tensor convert_to_f32(Tensor input);
// In-place variant: writes the float32 cast of input into the
// caller-provided output tensor.
void convert_to_f32_(Tensor output, Tensor input);

} // namespace infinicore::op

include/infinicore/ops/topkrouter.hpp

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -34,41 +34,3 @@ void topkrouter_(Tensor values_output,
3434
size_t topk);
3535

3636
} // namespace infinicore::op
37-
38-
#pragma once
39-
40-
#include "../device.hpp"
41-
#include "common/op.hpp"
42-
43-
namespace infinicore::op {
44-
45-
class TopKRouter {
46-
public:
47-
// values_output: [N, topk] float32
48-
// indices_output: [N, topk] int32
49-
// input: [N, width] float16/bfloat16/float32
50-
// correction_bias: [width] float32
51-
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, float, size_t);
52-
static void execute(Tensor values_output,
53-
Tensor indices_output,
54-
Tensor input,
55-
Tensor correction_bias,
56-
float routed_scaling_factor,
57-
size_t topk);
58-
static common::OpDispatcher<schema> &dispatcher();
59-
};
60-
61-
std::pair<Tensor, Tensor> topkrouter(Tensor input,
62-
Tensor correction_bias,
63-
float routed_scaling_factor,
64-
size_t topk);
65-
66-
void topkrouter_(Tensor values_output,
67-
Tensor indices_output,
68-
Tensor input,
69-
Tensor correction_bias,
70-
float routed_scaling_factor,
71-
size_t topk);
72-
73-
} // namespace infinicore::op
74-

include/infiniop.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "infiniop/ops/cdist.h"
2727
#include "infiniop/ops/clip.h"
2828
#include "infiniop/ops/conv.h"
29+
#include "infiniop/ops/convert_to_f32.h"
2930
#include "infiniop/ops/cross_entropy.h"
3031
#include "infiniop/ops/dequant/per_tensor_dequant_int8.h"
3132
#include "infiniop/ops/dequantize_awq.h"
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
#ifndef __INFINIOP_CONVERT_TO_F32_API_H__
#define __INFINIOP_CONVERT_TO_F32_API_H__

#include "../operator_descriptor.h"

// Opaque handle to a convert-to-f32 operator descriptor.
typedef struct InfiniopDescriptor *infiniopConvertToF32Descriptor_t;

// Creates a descriptor for casting tensor x into tensor y.
// y: destination descriptor (expected float32 -- the C++ frontend enforces
//    this; TODO confirm backends validate it as well); x: source descriptor.
__INFINI_C __export infiniStatus_t infiniopCreateConvertToF32Descriptor(infiniopHandle_t handle,
                                                                        infiniopConvertToF32Descriptor_t *desc_ptr,
                                                                        infiniopTensorDescriptor_t y,
                                                                        infiniopTensorDescriptor_t x);

// Queries the scratch-buffer size (bytes) required by infiniopConvertToF32.
__INFINI_C __export infiniStatus_t infiniopGetConvertToF32WorkspaceSize(infiniopConvertToF32Descriptor_t desc, size_t *size);

// Performs the cast: reads x, writes the converted result to y.
// workspace/workspace_size: scratch buffer of at least the queried size;
// stream: backend execution stream (semantics are backend-specific).
__INFINI_C __export infiniStatus_t infiniopConvertToF32(infiniopConvertToF32Descriptor_t desc,
                                                        void *workspace,
                                                        size_t workspace_size,
                                                        void *y,
                                                        const void *x,
                                                        void *stream);

// Releases a descriptor created by infiniopCreateConvertToF32Descriptor.
__INFINI_C __export infiniStatus_t infiniopDestroyConvertToF32Descriptor(infiniopConvertToF32Descriptor_t desc);

#endif // __INFINIOP_CONVERT_TO_F32_API_H__

python/infinicore/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
from infinicore.ops.broadcast_to import broadcast_to
7070
from infinicore.ops.cat import cat
7171
from infinicore.ops.cdist import cdist
72+
from infinicore.ops.convert_to_f32 import convert_to_f32
7273
from infinicore.ops.cross_entropy import cross_entropy
7374
from infinicore.ops.diff import diff
7475
from infinicore.ops.digamma import digamma
@@ -216,6 +217,7 @@
216217
"unsqueeze",
217218
"rearrange",
218219
"cross_entropy",
220+
"convert_to_f32",
219221
"tan",
220222
"empty",
221223
"empty_like",
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from infinicore.lib import _infinicore
2+
from infinicore.tensor import Tensor
3+
4+
def convert_to_f32(input, *, out=None):
    """Cast ``input`` to float32.

    When ``out`` is given, the result is written into it in place and
    ``out`` is returned; otherwise a freshly allocated tensor is returned.
    """
    if out is not None:
        _infinicore.convert_to_f32_(out._underlying, input._underlying)
        return out
    return Tensor(_infinicore.convert_to_f32(input._underlying))
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#include "infinicore/ops/convert_to_f32.hpp"
2+
3+
#include "../../utils.hpp"
4+
#include "infinicore/dtype.hpp"
5+
6+
#include <stdexcept>
7+
8+
namespace infinicore::op {
9+
10+
common::OpDispatcher<ConvertToF32::schema> &ConvertToF32::dispatcher() {
11+
static common::OpDispatcher<ConvertToF32::schema> dispatcher_;
12+
return dispatcher_;
13+
};
14+
15+
void ConvertToF32::execute(Tensor output, Tensor input) {
16+
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(output, input);
17+
if (output->dtype() != DataType::F32) {
18+
throw std::runtime_error("convert_to_f32: output dtype must be F32");
19+
}
20+
infinicore::context::setDevice(output->device());
21+
auto device_type = output->device().getType();
22+
auto func = dispatcher().lookup(device_type);
23+
24+
if (func == nullptr) {
25+
throw std::runtime_error("No convert_to_f32 implementation found for device type: "
26+
+ std::to_string(static_cast<int>(device_type)));
27+
}
28+
29+
func(output, input);
30+
}
31+
32+
Tensor convert_to_f32(Tensor input) {
33+
Shape shape = input->shape();
34+
auto output = Tensor::empty(shape, DataType::F32, input->device());
35+
convert_to_f32_(output, input);
36+
return output;
37+
}
38+
39+
void convert_to_f32_(Tensor output, Tensor input) {
40+
ConvertToF32::execute(output, input);
41+
}
42+
43+
} // namespace infinicore::op
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/convert_to_f32.hpp"

#include <infiniop.h>

namespace infinicore::op::convert_to_f32_impl::infiniop {

// Per-thread, per-device LRU-style cache of infiniop descriptors
// (capacity 100), keyed by a hash of the (output, input) tensors.
// The evictor destroys the descriptor when an entry is dropped.
// NOTE(review): assumes hash_combine folds in everything the descriptor
// depends on (shape/strides/dtype) so distinct layouts cannot collide
// on the same key -- TODO confirm against hash.hpp.
thread_local common::OpCache<size_t, infiniopConvertToF32Descriptor_t> caches(
    100,
    [](infiniopConvertToF32Descriptor_t &desc) {
        if (desc != nullptr) {
            INFINICORE_CHECK_ERROR(infiniopDestroyConvertToF32Descriptor(desc));
            desc = nullptr;
        }
    });

// infiniop-backed kernel: builds (or reuses) a descriptor, allocates the
// required workspace, and launches the cast on the current stream.
void calculate(Tensor output, Tensor input) {
    // Cache key for this (output, input) descriptor pair.
    size_t seed = hash_combine(output, input);

    auto device = context::getDevice();
    auto &cache = caches.getCache(device);

    auto desc_opt = cache.get(seed);
    infiniopConvertToF32Descriptor_t desc = nullptr;

    if (!desc_opt) {
        // Cache miss: create a descriptor for these tensor layouts and
        // remember it for subsequent identical calls.
        INFINICORE_CHECK_ERROR(infiniopCreateConvertToF32Descriptor(
            context::getInfiniopHandle(device), &desc,
            output->desc(), input->desc()));
        cache.put(seed, desc);
    } else {
        desc = *desc_opt;
    }

    // Workspace size is re-queried on every call, even for cached
    // descriptors. NOTE(review): if the size is fixed per descriptor,
    // it could be cached alongside it -- confirm before changing.
    size_t workspace_size = 0;
    INFINICORE_CHECK_ERROR(infiniopGetConvertToF32WorkspaceSize(desc, &workspace_size));
    std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);

    INFINICORE_CHECK_ERROR(infiniopConvertToF32(
        desc, workspace->data(), workspace_size,
        output->data(), input->data(), context::getStream()));
}

// Self-registration at static-initialization time: installs `calculate`
// for all device types (the `false` flag's meaning is defined by
// OpDispatcher::registerAll -- presumably "do not overwrite existing
// entries"; verify against common/op.hpp).
static bool registered = []() {
    ConvertToF32::dispatcher().registerAll(&calculate, false);
    return true;
}();

} // namespace infinicore::op::convert_to_f32_impl::infiniop

src/infinicore/pybind11/ops.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "ops/cat.hpp"
2929
#include "ops/causal_softmax.hpp"
3030
#include "ops/cdist.hpp"
31+
#include "ops/convert_to_f32.hpp"
3132
#include "ops/cross_entropy.hpp"
3233
#include "ops/diff.hpp"
3334
#include "ops/digamma.hpp"
@@ -163,6 +164,7 @@ inline void bind(py::module &m) {
163164
bind_pad(m);
164165
bind_prelu(m);
165166
bind_random_sample(m);
167+
bind_convert_to_f32(m);
166168
bind_cross_entropy(m);
167169
bind_hypot(m);
168170
bind_take(m);

0 commit comments

Comments
 (0)