Skip to content

Commit 446d0a7

Browse files
committed
Update asum, blas_amax, blas_amin, blas_copy and swap to use the graph framework
1 parent 7317e42 commit 446d0a7

21 files changed

Lines changed: 250 additions & 297 deletions

File tree

include/infinicore/ops/asum.hpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
11
#pragma once
22

33
#include "../device.hpp"
4+
#include "../graph/graph.hpp"
45
#include "common/op.hpp"
56

67
namespace infinicore::op {
78

8-
class Asum {
9-
public:
10-
using schema = void (*)(Tensor, Tensor);
11-
static void execute(Tensor result, Tensor x);
12-
static common::OpDispatcher<schema> &dispatcher();
13-
};
9+
INFINICORE_GRAPH_OP_CLASS(Asum, const Tensor &, Tensor);
1410

15-
Tensor asum(Tensor x);
16-
void asum_(Tensor result, Tensor x);
11+
Tensor asum(const Tensor &x);
12+
void asum_(const Tensor &x, Tensor result);
1713

1814
} // namespace infinicore::op
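include/infinicore/ops/blas_amax.hpp (path inferred from the sibling headers; the file header is missing from the extracted page)

Lines changed: 4 additions & 8 deletions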
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
11
#pragma once
22

33
#include "../device.hpp"
4+
#include "../graph/graph.hpp"
45
#include "common/op.hpp"
56

67
namespace infinicore::op {
78

8-
class BlasAmax {
9-
public:
10-
using schema = void (*)(Tensor, Tensor);
11-
static void execute(Tensor result, Tensor x);
12-
static common::OpDispatcher<schema> &dispatcher();
13-
};
9+
INFINICORE_GRAPH_OP_CLASS(BlasAmax, const Tensor &, Tensor);
1410

15-
Tensor blas_amax(Tensor x);
16-
void blas_amax_(Tensor result, Tensor x);
11+
Tensor blas_amax(const Tensor &x);
12+
void blas_amax_(const Tensor &x, Tensor result);
1713

1814
} // namespace infinicore::op
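include/infinicore/ops/blas_amin.hpp (path inferred from the sibling headers; the file header is missing from the extracted page)

Lines changed: 4 additions & 8 deletions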
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
11
#pragma once
22

33
#include "../device.hpp"
4+
#include "../graph/graph.hpp"
45
#include "common/op.hpp"
56

67
namespace infinicore::op {
78

8-
class BlasAmin {
9-
public:
10-
using schema = void (*)(Tensor, Tensor);
11-
static void execute(Tensor result, Tensor x);
12-
static common::OpDispatcher<schema> &dispatcher();
13-
};
9+
INFINICORE_GRAPH_OP_CLASS(BlasAmin, const Tensor &, Tensor);
1410

15-
Tensor blas_amin(Tensor x);
16-
void blas_amin_(Tensor result, Tensor x);
11+
Tensor blas_amin(const Tensor &x);
12+
void blas_amin_(const Tensor &x, Tensor result);
1713

1814
} // namespace infinicore::op
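include/infinicore/ops/blas_copy.hpp (path inferred from the sibling headers; the file header is missing from the extracted page)

Lines changed: 3 additions & 7 deletions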
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
11
#pragma once
22

33
#include "../device.hpp"
4+
#include "../graph/graph.hpp"
45
#include "common/op.hpp"
56

67
namespace infinicore::op {
78

8-
class BlasCopy {
9-
public:
10-
using schema = void (*)(Tensor, Tensor);
11-
static void execute(Tensor x, Tensor y);
12-
static common::OpDispatcher<schema> &dispatcher();
13-
};
9+
INFINICORE_GRAPH_OP_CLASS(BlasCopy, const Tensor &, Tensor);
1410

15-
void blas_copy_(Tensor x, Tensor y);
11+
void blas_copy_(const Tensor &x, Tensor y);
1612

1713
} // namespace infinicore::op

include/infinicore/ops/swap.hpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
#pragma once
22

33
#include "../device.hpp"
4+
#include "../graph/graph.hpp"
45
#include "common/op.hpp"
56

67
namespace infinicore::op {
78

8-
class Swap {
9-
public:
10-
using schema = void (*)(Tensor, Tensor);
11-
static void execute(Tensor x, Tensor y);
12-
static common::OpDispatcher<schema> &dispatcher();
13-
};
9+
INFINICORE_GRAPH_OP_CLASS(Swap, Tensor, Tensor);
1410

1511
void swap_(Tensor x, Tensor y);
1612

python/infinicore/ops/asum.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ def asum(x: Tensor, *, out=None):
66
if out is None:
77
return Tensor(_infinicore.asum(x._underlying))
88

9-
_infinicore.asum_(out._underlying, x._underlying)
9+
_infinicore.asum_(x._underlying, out._underlying)
10+
1011
return out

python/infinicore/ops/blas_amax.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ def blas_amax(x: Tensor, *, out=None):
66
if out is None:
77
return Tensor(_infinicore.blas_amax(x._underlying))
88

9-
_infinicore.blas_amax_(out._underlying, x._underlying)
9+
_infinicore.blas_amax_(x._underlying, out._underlying)
10+
1011
return out

python/infinicore/ops/blas_amin.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ def blas_amin(x: Tensor, *, out=None):
66
if out is None:
77
return Tensor(_infinicore.blas_amin(x._underlying))
88

9-
_infinicore.blas_amin_(out._underlying, x._underlying)
9+
_infinicore.blas_amin_(x._underlying, out._underlying)
10+
1011
return out

src/infinicore/ops/asum/asum.cc

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,25 @@
44

55
namespace infinicore::op {
66

7-
common::OpDispatcher<Asum::schema> &Asum::dispatcher() {
8-
static common::OpDispatcher<Asum::schema> dispatcher_;
9-
return dispatcher_;
10-
};
11-
12-
void Asum::execute(Tensor result, Tensor x) {
13-
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(result, x);
14-
infinicore::context::setDevice(result->device());
15-
dispatcher().lookup(result->device().getType())(result, x);
7+
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Asum);
8+
9+
Asum::Asum(const Tensor &x, Tensor result) {
10+
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(x, result);
11+
INFINICORE_GRAPH_OP_DISPATCH(result->device().getType(), x, result);
12+
}
13+
14+
void Asum::execute(const Tensor &x, Tensor result) {
15+
INFINICORE_GRAPH_OP_RECORD_OR_RUN(Asum, x, result);
1616
}
1717

18-
Tensor asum(Tensor x) {
18+
Tensor asum(const Tensor &x) {
1919
auto result = Tensor::empty({}, x->dtype(), x->device());
20-
asum_(result, x);
20+
asum_(x, result);
2121
return result;
2222
}
2323

24-
void asum_(Tensor result, Tensor x) {
25-
Asum::execute(result, x);
24+
void asum_(const Tensor &x, Tensor result) {
25+
Asum::execute(x, result);
2626
}
2727

2828
} // namespace infinicore::op
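src/infinicore/ops/asum/asum_infiniop.cc (presumed path — the file header is missing from the extracted page; the directory follows from the relative include "../infiniop_impl.hpp", the file name is a guess)

Lines changed: 34 additions & 40 deletions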
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,50 @@
1-
#include "../../utils.hpp"
2-
#include "infinicore/common/hash.hpp"
31
#include "infinicore/ops/asum.hpp"
4-
#include "infinicore/ops/common/cache.hpp"
5-
#include <infiniop.h>
2+
3+
#include "../infiniop_impl.hpp"
64

75
namespace infinicore::op::asum_impl::infiniop {
86

9-
thread_local common::OpCache<size_t, infiniopAsumDescriptor_t> caches(
10-
100, // capacity
11-
[](infiniopAsumDescriptor_t &desc) {
12-
if (desc != nullptr) {
13-
INFINICORE_CHECK_ERROR(infiniopDestroyAsumDescriptor(desc));
14-
desc = nullptr;
15-
}
16-
});
7+
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Asum, 100);
178

18-
void calculate(Tensor result, Tensor x) {
19-
size_t seed = hash_combine(result, x);
9+
struct PlannedMeta {
10+
std::shared_ptr<Descriptor> descriptor;
11+
graph::GraphTensor workspace, x, result;
12+
};
2013

21-
auto device_type = context::getDevice().getType();
22-
auto device_index = context::getDevice().getIndex();
14+
void *plan(const Tensor &x, Tensor result) {
15+
size_t seed = hash_combine(x, result);
2316

24-
auto &cache = caches.getCache(device_type, device_index);
17+
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
18+
Descriptor, descriptor, Asum,
19+
seed,
20+
x->desc(), result->desc());
2521

26-
auto desc_opt = cache.get(seed);
27-
infiniopAsumDescriptor_t desc = nullptr;
22+
INFINIOP_WORKSPACE_TENSOR(workspace, Asum, descriptor);
2823

29-
if (!desc_opt) {
30-
INFINICORE_CHECK_ERROR(infiniopCreateAsumDescriptor(
31-
context::getInfiniopHandle(result->device()), &desc,
32-
x->desc(), result->desc()));
33-
cache.put(seed, desc);
34-
} else {
35-
desc = *desc_opt;
36-
}
24+
return new PlannedMeta{
25+
descriptor,
26+
graph::GraphTensor(workspace),
27+
graph::GraphTensor(x),
28+
graph::GraphTensor(result)};
29+
}
3730

38-
size_t workspace_size = 0;
39-
INFINICORE_CHECK_ERROR(infiniopGetAsumWorkspaceSize(desc, &workspace_size));
40-
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
31+
void run(void *planned_meta) {
32+
auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
4133

4234
INFINICORE_CHECK_ERROR(infiniopAsum(
43-
desc, workspace->data(), workspace_size,
44-
x->data(), result->data(), context::getStream()));
35+
planned->descriptor->desc,
36+
planned->workspace->data(),
37+
planned->workspace->numel(),
38+
planned->x->data(),
39+
planned->result->data(),
40+
context::getStream()));
41+
}
42+
43+
void cleanup(void **planned_meta_ptr) {
44+
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
45+
*planned_meta_ptr = nullptr;
4546
}
4647

47-
static bool registered = []() {
48-
Asum::dispatcher().registerDevice({Device::Type::CPU,
49-
Device::Type::CAMBRICON,
50-
Device::Type::METAX},
51-
&calculate,
52-
false);
53-
return true;
54-
}();
48+
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Asum, &plan, &run, &cleanup);
5549

5650
} // namespace infinicore::op::asum_impl::infiniop

0 commit comments

Comments
 (0)