|
1 | | -#include "../../utils.hpp" |
2 | | -#include "infinicore/common/hash.hpp" |
3 | 1 | #include "infinicore/ops/axpy.hpp" |
4 | | -#include "infinicore/ops/common/cache.hpp" |
5 | | -#include <infiniop.h> |
| 2 | + |
| 3 | +#include "../infiniop_impl.hpp" |
6 | 4 |
|
7 | 5 | namespace infinicore::op::axpy_impl::infiniop { |
8 | 6 |
|
9 | | -thread_local common::OpCache<size_t, infiniopAxpyDescriptor_t> caches( |
10 | | - 100, // capacity |
11 | | - [](infiniopAxpyDescriptor_t &desc) { |
12 | | - if (desc != nullptr) { |
13 | | - INFINICORE_CHECK_ERROR(infiniopDestroyAxpyDescriptor(desc)); |
14 | | - desc = nullptr; |
15 | | - } |
16 | | - }); |
| 7 | +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Axpy, 100); |
17 | 8 |
|
18 | | -void calculate(Tensor alpha, Tensor x, Tensor y) { |
19 | | - size_t seed = hash_combine(alpha, x, y); |
| 9 | +struct PlannedMeta { |
| 10 | + std::shared_ptr<Descriptor> descriptor; |
| 11 | + graph::GraphTensor workspace, alpha, x, y; |
| 12 | +}; |
20 | 13 |
|
21 | | - auto device_type = context::getDevice().getType(); |
22 | | - auto device_index = context::getDevice().getIndex(); |
| 14 | +void *plan(const Tensor &alpha, const Tensor &x, Tensor y) { |
| 15 | + size_t seed = hash_combine(y, alpha, x); |
23 | 16 |
|
24 | | - auto &cache = caches.getCache(device_type, device_index); |
| 17 | + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( |
| 18 | + Descriptor, descriptor, Axpy, |
| 19 | + seed, |
| 20 | + alpha->desc(), x->desc(), y->desc()); |
25 | 21 |
|
26 | | - auto desc_opt = cache.get(seed); |
27 | | - infiniopAxpyDescriptor_t desc = nullptr; |
| 22 | + INFINIOP_WORKSPACE_TENSOR(workspace, Axpy, descriptor); |
28 | 23 |
|
29 | | - if (!desc_opt) { |
30 | | - INFINICORE_CHECK_ERROR(infiniopCreateAxpyDescriptor( |
31 | | - context::getInfiniopHandle(y->device()), &desc, |
32 | | - alpha->desc(), x->desc(), y->desc())); |
33 | | - cache.put(seed, desc); |
34 | | - } else { |
35 | | - desc = *desc_opt; |
36 | | - } |
| 24 | + return new PlannedMeta{ |
| 25 | + descriptor, |
| 26 | + graph::GraphTensor(workspace), |
| 27 | + graph::GraphTensor(alpha), |
| 28 | + graph::GraphTensor(x), |
| 29 | + graph::GraphTensor(y)}; |
| 30 | +} |
37 | 31 |
|
38 | | - size_t workspace_size = 0; |
39 | | - INFINICORE_CHECK_ERROR(infiniopGetAxpyWorkspaceSize(desc, &workspace_size)); |
40 | | - std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size); |
| 32 | +void run(void *planned_meta) { |
| 33 | + auto planned = reinterpret_cast<PlannedMeta *>(planned_meta); |
41 | 34 |
|
42 | 35 | INFINICORE_CHECK_ERROR(infiniopAxpy( |
43 | | - desc, workspace->data(), workspace_size, |
44 | | - alpha->data(), x->data(), y->data(), context::getStream())); |
| 36 | + planned->descriptor->desc, |
| 37 | + planned->workspace->data(), |
| 38 | + planned->workspace->numel(), |
| 39 | + planned->alpha->data(), |
| 40 | + planned->x->data(), |
| 41 | + planned->y->data(), |
| 42 | + context::getStream())); |
| 43 | +} |
| 44 | + |
| 45 | +void cleanup(void **planned_meta_ptr) { |
| 46 | + delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr); |
| 47 | + *planned_meta_ptr = nullptr; |
45 | 48 | } |
46 | 49 |
|
47 | | -static bool registered = []() { |
48 | | - Axpy::dispatcher().registerDevice({Device::Type::CPU, |
49 | | - Device::Type::CAMBRICON, |
50 | | - Device::Type::METAX}, |
51 | | - &calculate, |
52 | | - false); |
53 | | - return true; |
54 | | -}(); |
| 50 | +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Axpy, &plan, &run, &cleanup); |
55 | 51 |
|
56 | 52 | } // namespace infinicore::op::axpy_impl::infiniop |
0 commit comments