|
1 | | -#include "../../utils.hpp" |
2 | | -#include "infinicore/common/hash.hpp" |
3 | 1 | #include "infinicore/ops/asum.hpp" |
4 | | -#include "infinicore/ops/common/cache.hpp" |
5 | | -#include <infiniop.h> |
| 2 | + |
| 3 | +#include "../infiniop_impl.hpp" |
6 | 4 |
|
7 | 5 | namespace infinicore::op::asum_impl::infiniop { |
8 | 6 |
|
9 | | -thread_local common::OpCache<size_t, infiniopAsumDescriptor_t> caches( |
10 | | - 100, // capacity |
11 | | - [](infiniopAsumDescriptor_t &desc) { |
12 | | - if (desc != nullptr) { |
13 | | - INFINICORE_CHECK_ERROR(infiniopDestroyAsumDescriptor(desc)); |
14 | | - desc = nullptr; |
15 | | - } |
16 | | - }); |
// The macro row below takes the place of the hand-written thread_local
// OpCache removed above; its third argument (100) equals the capacity the
// removed cache was constructed with.
// NOTE(review): the macro definition is not visible here (presumably it comes
// from the newly included ../infiniop_impl.hpp and declares the `Descriptor`
// type used by PlannedMeta and plan()) -- confirm against the macro.
| 7 | +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Asum, 100); |
17 | 8 |
|
18 | | -void calculate(Tensor result, Tensor x) { |
19 | | - size_t seed = hash_combine(result, x); |
// PlannedMeta bundles what plan() prepares and run() later reads: a shared
// pointer to the Asum descriptor plus graph::GraphTensor members for the
// workspace, x and result. plan() allocates one instance with `new`, and
// cleanup() deletes it.
| 9 | +struct PlannedMeta { |
| 10 | + std::shared_ptr<Descriptor> descriptor; |
| 11 | + graph::GraphTensor workspace, x, result; |
| 12 | +}; |
20 | 13 |
|
21 | | - auto device_type = context::getDevice().getType(); |
22 | | - auto device_index = context::getDevice().getIndex(); |
// plan(): prepares an Asum invocation without executing it.
//   x      - tensor whose desc() is passed as the first tensor descriptor
//   result - tensor whose desc() is passed as the second tensor descriptor
// Returns a heap-allocated PlannedMeta as an opaque pointer; the caller owns
// it and is expected to release it through cleanup().
// The seed is now hashed as (x, result); the removed calculate() hashed
// (result, x).
| 14 | +void *plan(const Tensor &x, Tensor result) { |
| 15 | + size_t seed = hash_combine(x, result); |
23 | 16 |
|
24 | | - auto &cache = caches.getCache(device_type, device_index); |
// NOTE(review): presumably this macro looks up `seed` in the descriptor cache
// and creates the descriptor from x->desc()/result->desc() on a miss,
// introducing a local named `descriptor` (it is used below) -- confirm
// against the macro definition, which is not visible here.
| 17 | + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( |
| 18 | + Descriptor, descriptor, Asum, |
| 19 | + seed, |
| 20 | + x->desc(), result->desc()); |
25 | 21 |
|
26 | | - auto desc_opt = cache.get(seed); |
27 | | - infiniopAsumDescriptor_t desc = nullptr; |
// NOTE(review): presumably declares a local `workspace` tensor sized from the
// descriptor's workspace requirement (it is wrapped in a GraphTensor below)
// -- confirm against the macro definition.
| 22 | + INFINIOP_WORKSPACE_TENSOR(workspace, Asum, descriptor); |
28 | 23 |
|
29 | | - if (!desc_opt) { |
30 | | - INFINICORE_CHECK_ERROR(infiniopCreateAsumDescriptor( |
31 | | - context::getInfiniopHandle(result->device()), &desc, |
32 | | - x->desc(), result->desc())); |
33 | | - cache.put(seed, desc); |
34 | | - } else { |
35 | | - desc = *desc_opt; |
36 | | - } |
// Raw `new`: the result is returned as void*, so ownership cannot be carried
// in the type; the matching `delete` is in cleanup().
| 24 | + return new PlannedMeta{ |
| 25 | + descriptor, |
| 26 | + graph::GraphTensor(workspace), |
| 27 | + graph::GraphTensor(x), |
| 28 | + graph::GraphTensor(result)}; |
| 29 | +} |
37 | 30 |
|
38 | | - size_t workspace_size = 0; |
39 | | - INFINICORE_CHECK_ERROR(infiniopGetAsumWorkspaceSize(desc, &workspace_size)); |
40 | | - std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size); |
// run(): executes a previously planned Asum call.
//   planned_meta - opaque pointer that is cast back to PlannedMeta*; it must
//                  be a value returned by plan() that has not yet been passed
//                  to cleanup().
| 31 | +void run(void *planned_meta) { |
| 32 | + auto planned = reinterpret_cast<PlannedMeta *>(planned_meta); |
41 | 33 |
|
// Calls infiniopAsum on the current context stream; the status it returns is
// passed to INFINICORE_CHECK_ERROR.
// NOTE(review): the workspace-size argument is now workspace->numel(), where
// the removed code passed the byte count from infiniopGetAsumWorkspaceSize.
// This assumes the workspace tensor has one-byte elements -- TODO confirm.
42 | 34 | INFINICORE_CHECK_ERROR(infiniopAsum( |
43 | | - desc, workspace->data(), workspace_size, |
44 | | - x->data(), result->data(), context::getStream())); |
| 35 | + planned->descriptor->desc, |
| 36 | + planned->workspace->data(), |
| 37 | + planned->workspace->numel(), |
| 38 | + planned->x->data(), |
| 39 | + planned->result->data(), |
| 40 | + context::getStream())); |
| 41 | +} |
| 42 | + |
// cleanup(): releases the PlannedMeta allocated by plan().
//   planned_meta_ptr - address of the opaque pointer returned by plan(); the
//                      pointee is deleted and the caller's pointer is then
//                      set to nullptr, so a later call with the same slot
//                      deletes a null pointer, which is a no-op.
| 43 | +void cleanup(void **planned_meta_ptr) { |
| 44 | + delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr); |
| 45 | + *planned_meta_ptr = nullptr; |
45 | 46 | } |
46 | 47 |
|
47 | | -static bool registered = []() { |
48 | | - Asum::dispatcher().registerDevice({Device::Type::CPU, |
49 | | - Device::Type::CAMBRICON, |
50 | | - Device::Type::METAX}, |
51 | | - &calculate, |
52 | | - false); |
53 | | - return true; |
54 | | -}(); |
// Registers the plan/run/cleanup triple for the Asum op, replacing the
// removed static initializer that registered `calculate` for CPU, CAMBRICON
// and METAX only.
// NOTE(review): the macro name suggests registration for every device type,
// which would widen the set of supported devices -- confirm against the macro
// definition and the infiniop backends that actually implement Asum.
| 48 | +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Asum, &plan, &run, &cleanup); |
55 | 49 |
|
56 | 50 | } // namespace infinicore::op::asum_impl::infiniop |
0 commit comments