Skip to content

Commit d3bae33

Browse files
committed
issue/900 - adapt to graph and adjust test script
1 parent 4615ecf commit d3bae33

7 files changed

Lines changed: 334 additions & 819 deletions

File tree

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
#pragma once
22

3+
#include "../device.hpp"
4+
#include "../graph/graph.hpp"
35
#include "common/op.hpp"
46

57
namespace infinicore::op {
68

7-
class Embedding {
8-
public:
9-
using schema = void (*)(Tensor, Tensor, Tensor);
10-
static void execute(Tensor out, Tensor input, Tensor weight);
11-
static common::OpDispatcher<schema> &dispatcher();
12-
};
9+
INFINICORE_GRAPH_OP_CLASS(Embedding, Tensor, const Tensor &, const Tensor &);
1310

14-
Tensor embedding(Tensor input, Tensor weight);
15-
void embedding_(Tensor out, Tensor input, Tensor weight);
11+
Tensor embedding(const Tensor &input, const Tensor &weight);
12+
void embedding_(Tensor out, const Tensor &input, const Tensor &weight);
1613
} // namespace infinicore::op

src/infinicore/ops/embedding/embedding.cc

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,19 @@
55
#include <stdexcept>
66

77
namespace infinicore::op {
8+
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Embedding);
89

9-
common::OpDispatcher<Embedding::schema> &Embedding::dispatcher() {
10-
static common::OpDispatcher<Embedding::schema> dispatcher_;
11-
return dispatcher_;
12-
}
13-
14-
void Embedding::execute(Tensor out, Tensor input, Tensor weight) {
15-
// Check that all tensors are on the same device
16-
// This is critical: if input is on CPU while out/weight are on GPU,
17-
// passing CPU pointer to CUDA kernel will cause memory access errors
10+
// Constructs the Embedding graph op for (out, input, weight).
// First asserts (per the macro name) that all three tensors live on the same device;
// the comment removed by this commit explained why: if input is on CPU while
// out/weight are on GPU, a CPU pointer would be handed to a CUDA kernel.
// Then dispatches on the device type of `out`.
// NOTE(review): the previous execute() called context::setDevice(out->device())
// before the dispatcher lookup; that call is gone here — confirm the dispatch
// macro or the graph runtime selects the device.
Embedding::Embedding(Tensor out, const Tensor &input, const Tensor &weight) {
1811
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, input, weight);
12+
INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), out, input, weight);
13+
}
1914

20-
// Set device context
21-
infinicore::context::setDevice(out->device());
22-
23-
// Use dispatcher to lookup kernel (infiniop implementation)
24-
dispatcher().lookup(out->device().getType())(out, input, weight);
15+
// Static entry point for the op: forwards (out, input, weight) to
// INFINICORE_GRAPH_OP_RECORD_OR_RUN for the Embedding op.
// NOTE(review): presumably records the op into a graph while capturing and
// runs it immediately otherwise — confirm against the macro definition.
void Embedding::execute(Tensor out, const Tensor &input, const Tensor &weight) {
16+
INFINICORE_GRAPH_OP_RECORD_OR_RUN(Embedding, out, input, weight);
2517
}
2618

27-
Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the indices to extract
28-
Tensor weight // Weight: Embedding matrix of floating point type with shape (V, embedding_dim), where V = maximum index + 1
19+
Tensor embedding(const Tensor &input, // LongTensor of arbitrary shape containing the indices to extract
20+
const Tensor &weight // Weight: Embedding matrix of floating point type with shape (V, embedding_dim), where V = maximum index + 1
2921
) {
3022
auto input_shape = input->shape();
3123
auto weight_shape = weight->shape();
@@ -40,7 +32,7 @@ Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the i
4032
return inputs_embeds;
4133
}
4234

43-
void embedding_(Tensor out, Tensor input, Tensor weight) {
35+
// Variant of embedding() that writes into the caller-supplied `out` tensor;
// simply forwards to Embedding::execute.
void embedding_(Tensor out, const Tensor &input, const Tensor &weight) {
4436
Embedding::execute(out, input, weight);
4537
}
4638

Lines changed: 29 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,44 @@
1-
#include "../../utils.hpp"
2-
#include "infinicore/common/hash.hpp"
3-
#include "infinicore/ops/common/cache.hpp"
1+
#include "../infiniop_impl.hpp"
42
#include "infinicore/ops/embedding.hpp"
5-
#include <infiniop.h>
63

74
namespace infinicore::op::embedding_impl::infiniop {
85

9-
thread_local common::OpCache<size_t, infiniopEmbeddingDescriptor_t> caches(
10-
100, // capacity
11-
[](infiniopEmbeddingDescriptor_t &desc) {
12-
if (desc != nullptr) {
13-
INFINICORE_CHECK_ERROR(infiniopDestroyEmbeddingDescriptor(desc));
14-
desc = nullptr;
15-
}
16-
});
6+
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Embedding, 100);
177

18-
void calculate(Tensor out, Tensor input, Tensor weight) {
8+
// State built by plan() and consumed by run(): the infiniop descriptor plus
// graph::GraphTensor wrappers for the three tensors of the op.
struct PlannedMeta {
9+
// Shared handle to the descriptor obtained in plan().
std::shared_ptr<Descriptor> descriptor;
10+
// Output, index input, and embedding weight, wrapped for graph use.
graph::GraphTensor out, input, weight;
11+
};
12+
13+
// Planning step: obtains a descriptor for (out, input, weight) and packages it
// together with GraphTensor wrappers of the tensors into a heap-allocated
// PlannedMeta, returned as an opaque void*.
// The returned pointer is owned by the caller and must be released through
// cleanup(), which deletes the PlannedMeta.
// (Lines marked `-` below are the removed body of the old calculate().)
void *plan(Tensor out, const Tensor &input, const Tensor &weight) {
1914
// Seed derived from the three tensors; passed to the descriptor macro below
// (presumably as the cache key — confirm against the macro definition).
size_t seed = hash_combine(out, input, weight);
2015

21-
auto device = context::getDevice();
22-
auto &cache = caches.getCache(device);
16+
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
17+
Descriptor, descriptor, Embedding,
18+
seed, out->desc(), input->desc(), weight->desc());
19+
20+
// NOTE(review): raw `new`; ownership leaves this function as void*, so the
// matching delete lives in cleanup(). Keep the two in sync.
auto planned = new PlannedMeta{
21+
descriptor,
22+
graph::GraphTensor(out),
23+
graph::GraphTensor(input),
24+
graph::GraphTensor(weight)};
2325

24-
auto desc_opt = cache.get(seed);
25-
infiniopEmbeddingDescriptor_t desc = nullptr;
26+
return planned;
27+
}
2628

27-
if (!desc_opt) {
28-
INFINICORE_CHECK_ERROR(infiniopCreateEmbeddingDescriptor(
29-
context::getInfiniopHandle(device), &desc,
30-
out->desc(), input->desc(), weight->desc()));
31-
cache.put(seed, desc);
32-
} else {
33-
desc = *desc_opt;
34-
}
29+
// Execution step: recovers the PlannedMeta from the opaque pointer and calls
// infiniopEmbedding with the planned descriptor, the data pointers of
// out/input/weight, and the stream from context::getStream().
// Any error status is routed through INFINICORE_CHECK_ERROR.
// (Lines marked `-` below are the removed arguments of the old call.)
void run(void *planned_meta) {
30+
// NOTE(review): void* -> PlannedMeta* only needs static_cast; reinterpret_cast
// works but is stronger than required.
auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
3531

3632
INFINICORE_CHECK_ERROR(infiniopEmbedding(
37-
desc,
38-
out->data(),
39-
input->data(),
40-
weight->data(),
41-
context::getStream()));
33+
planned->descriptor->desc,
34+
planned->out->data(), planned->input->data(), planned->weight->data(), context::getStream()));
35+
}
36+
37+
// Releases the PlannedMeta referenced through `planned_meta_ptr` and resets the
// caller's pointer to nullptr so it cannot be reused after deletion.
void cleanup(void **planned_meta_ptr) {
38+
// NOTE(review): casting void** to PlannedMeta** reads a void* object through a
// PlannedMeta* lvalue (type punning). Consider
// `delete static_cast<PlannedMeta *>(*planned_meta_ptr);` instead.
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
39+
*planned_meta_ptr = nullptr;
4240
}
4341

44-
static bool registered = []() {
45-
Embedding::dispatcher().registerAll(&calculate, false);
46-
return true;
47-
}();
42+
// Registers the plan/run/cleanup functions above for the Embedding op; replaces
// the old `registered` lambda that called Embedding::dispatcher().registerAll.
// NOTE(review): `cleanup` is passed without `&` while `plan` and `run` use it;
// likely equivalent via function-to-pointer decay, but consider making it consistent.
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Embedding, &plan, &run, cleanup);
4843

4944
} // namespace infinicore::op::embedding_impl::infiniop

0 commit comments

Comments
 (0)