#include "infinicore/ops/add_rms_norm.hpp"

#include "../infiniop_impl.hpp"

#include <cstddef> // size_t
#include <memory>  // std::shared_ptr

7 | 5 | namespace infinicore::op::add_rms_norm_impl::infiniop { |
8 | 6 |
|
9 | | -thread_local common::OpCache<size_t, infiniopAddRMSNormDescriptor_t> caches( |
10 | | - 100, // capacity |
11 | | - [](infiniopAddRMSNormDescriptor_t &desc) { |
12 | | - if (desc != nullptr) { |
13 | | - INFINICORE_CHECK_ERROR(infiniopDestroyAddRMSNormDescriptor(desc)); |
14 | | - desc = nullptr; |
15 | | - } |
16 | | - }); |
| 7 | +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, AddRMSNorm, 100); |
| 8 | + |
| 9 | +struct PlannedMeta { |
| 10 | + std::shared_ptr<Descriptor> descriptor; |
| 11 | + graph::GraphTensor workspace, out, residual, a, b, weight; |
| 12 | + float epsilon; |
| 13 | +}; |
17 | 14 |
|
18 | | -void calculate(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) { |
| 15 | +void *plan(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) { |
19 | 16 | size_t seed = hash_combine(y, residual_out, a, b, weight, epsilon); |
20 | 17 |
|
21 | | - auto device = context::getDevice(); |
22 | | - auto &cache = caches.getCache(device); |
| 18 | + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( |
| 19 | + Descriptor, descriptor, AddRMSNorm, |
| 20 | + seed, y->desc(), residual_out->desc(), |
| 21 | + a->desc(), b->desc(), weight->desc(), epsilon); |
| 22 | + |
| 23 | + INFINIOP_WORKSPACE_TENSOR(workspace, AddRMSNorm, descriptor); |
23 | 24 |
|
24 | | - auto desc_opt = cache.get(seed); |
25 | | - infiniopAddRMSNormDescriptor_t desc = nullptr; |
| 25 | + auto planned = new PlannedMeta{ |
| 26 | + descriptor, |
| 27 | + graph::GraphTensor(workspace), |
| 28 | + graph::GraphTensor(y), |
| 29 | + graph::GraphTensor(residual_out), |
| 30 | + graph::GraphTensor(a), |
| 31 | + graph::GraphTensor(b), |
| 32 | + graph::GraphTensor(weight), |
| 33 | + epsilon}; |
26 | 34 |
|
27 | | - if (!desc_opt) { |
28 | | - INFINICORE_CHECK_ERROR(infiniopCreateAddRMSNormDescriptor( |
29 | | - context::getInfiniopHandle(device), &desc, |
30 | | - y->desc(), a->desc(), b->desc(), weight->desc(), epsilon, residual_out->desc())); |
31 | | - cache.put(seed, desc); |
32 | | - } else { |
33 | | - desc = *desc_opt; |
34 | | - } |
| 35 | + return planned; |
| 36 | +} |
35 | 37 |
|
36 | | - size_t workspace_size = 0; |
37 | | - INFINICORE_CHECK_ERROR(infiniopGetAddRMSNormWorkspaceSize(desc, &workspace_size)); |
38 | | - std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size); |
| 38 | +void run(void *planned_meta) { |
| 39 | + auto planned = reinterpret_cast<PlannedMeta *>(planned_meta); |
39 | 40 |
|
40 | 41 | INFINICORE_CHECK_ERROR(infiniopAddRMSNorm( |
41 | | - desc, workspace->data(), workspace_size, |
42 | | - y->data(), a->data(), b->data(), weight->data(), residual_out->data(), context::getStream())); |
| 42 | + planned->descriptor->desc, planned->workspace->data(), planned->workspace->numel(), |
| 43 | + planned->out->data(), planned->residual->data(), planned->a->data(), planned->b->data(), planned->weight->data(), context::getStream())); |
| 44 | +} |
| 45 | + |
| 46 | +void cleanup(void **planned_meta_ptr) { |
| 47 | + delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr); |
| 48 | + *planned_meta_ptr = nullptr; |
43 | 49 | } |
44 | 50 |
|
45 | | -static bool registered = []() { |
46 | | - AddRMSNorm::dispatcher().registerAll(&calculate, false); |
47 | | - return true; |
48 | | -}(); |
| 51 | +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(AddRMSNorm, &plan, &run, &cleanup); |
49 | 52 |
|
50 | 53 | } // namespace infinicore::op::add_rms_norm_impl::infiniop |
0 commit comments