Skip to content

Commit 77cbb09

Browse files
committed
Add sigmoid op wrapper
Made-with: Cursor
1 parent e3dee1b commit 77cbb09

4 files changed

Lines changed: 111 additions & 0 deletions

File tree

include/infinicore/ops.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "ops/reciprocal.hpp"
3535
#include "ops/rms_norm.hpp"
3636
#include "ops/rope.hpp"
37+
#include "ops/sigmoid.hpp"
3738
#include "ops/silu.hpp"
3839
#include "ops/silu_and_mul.hpp"
3940
#include "ops/swiglu.hpp"

include/infinicore/ops/sigmoid.hpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#pragma once
2+
3+
#include "../device.hpp"
4+
#include "common/op.hpp"
5+
6+
namespace infinicore::op {
7+
8+
class Sigmoid {
9+
public:
10+
using schema = void (*)(Tensor, Tensor);
11+
static void execute(Tensor output, Tensor input);
12+
static common::OpDispatcher<schema> &dispatcher();
13+
};
14+
15+
Tensor sigmoid(Tensor input);
16+
void sigmoid_(Tensor output, Tensor input);
17+
18+
} // namespace infinicore::op
19+
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#include "infinicore/ops/sigmoid.hpp"
2+
3+
#include "../../utils.hpp"
4+
5+
#include <stdexcept>
6+
7+
namespace infinicore::op {
8+
9+
common::OpDispatcher<Sigmoid::schema> &Sigmoid::dispatcher() {
10+
static common::OpDispatcher<Sigmoid::schema> dispatcher_;
11+
return dispatcher_;
12+
};
13+
14+
void Sigmoid::execute(Tensor output, Tensor input) {
15+
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(output, input);
16+
infinicore::context::setDevice(output->device());
17+
auto device_type = output->device().getType();
18+
auto func = dispatcher().lookup(device_type);
19+
20+
if (func == nullptr) {
21+
throw std::runtime_error("No Sigmoid implementation found for device type: " + std::to_string(static_cast<int>(device_type)));
22+
}
23+
24+
func(output, input);
25+
}
26+
27+
Tensor sigmoid(Tensor input) {
28+
Shape shape = input->shape();
29+
auto output = Tensor::empty(shape, input->dtype(), input->device());
30+
sigmoid_(output, input);
31+
return output;
32+
}
33+
34+
void sigmoid_(Tensor output, Tensor input) {
35+
Sigmoid::execute(output, input);
36+
}
37+
38+
} // namespace infinicore::op
39+
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#include "../../utils.hpp"
2+
#include "infinicore/common/hash.hpp"
3+
#include "infinicore/ops/common/cache.hpp"
4+
#include "infinicore/ops/sigmoid.hpp"
5+
6+
#include <infiniop.h>
7+
8+
namespace infinicore::op::sigmoid_impl::infiniop {
9+
10+
thread_local common::OpCache<size_t, infiniopSigmoidDescriptor_t> caches(
11+
100, // capacity
12+
[](infiniopSigmoidDescriptor_t &desc) {
13+
if (desc != nullptr) {
14+
INFINICORE_CHECK_ERROR(infiniopDestroySigmoidDescriptor(desc));
15+
desc = nullptr;
16+
}
17+
});
18+
19+
void calculate(Tensor output, Tensor input) {
20+
size_t seed = hash_combine(output, input);
21+
22+
auto device = context::getDevice();
23+
auto &cache = caches.getCache(device);
24+
25+
auto desc_opt = cache.get(seed);
26+
infiniopSigmoidDescriptor_t desc = nullptr;
27+
28+
if (!desc_opt) {
29+
INFINICORE_CHECK_ERROR(infiniopCreateSigmoidDescriptor(
30+
context::getInfiniopHandle(device), &desc,
31+
output->desc(), input->desc()));
32+
cache.put(seed, desc);
33+
} else {
34+
desc = *desc_opt;
35+
}
36+
37+
size_t workspace_size = 0;
38+
INFINICORE_CHECK_ERROR(infiniopGetSigmoidWorkspaceSize(desc, &workspace_size));
39+
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
40+
41+
INFINICORE_CHECK_ERROR(infiniopSigmoid(
42+
desc, workspace->data(), workspace_size,
43+
output->data(), input->data(), context::getStream()));
44+
}
45+
46+
static bool registered = []() {
47+
Sigmoid::dispatcher().registerAll(&calculate, false);
48+
return true;
49+
}();
50+
51+
} // namespace infinicore::op::sigmoid_impl::infiniop
52+

0 commit comments

Comments
 (0)