Skip to content

Commit 7317e42

Browse files
committed
Update axpy, blas_dot, nrm2, rot and scal to use the graph framework
1 parent 20e0146 commit 7317e42

23 files changed

Lines changed: 257 additions & 296 deletions

File tree

include/infinicore/ops/axpy.hpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
11
#pragma once
22

33
#include "../device.hpp"
4+
#include "../graph/graph.hpp"
45
#include "common/op.hpp"
56

67
namespace infinicore::op {
78

8-
class Axpy {
9-
public:
10-
using schema = void (*)(Tensor, Tensor, Tensor);
11-
static void execute(Tensor alpha, Tensor x, Tensor y);
12-
static common::OpDispatcher<schema> &dispatcher();
13-
};
9+
INFINICORE_GRAPH_OP_CLASS(Axpy, const Tensor &, const Tensor &, Tensor);
1410

15-
void axpy_(Tensor alpha, Tensor x, Tensor y);
11+
void axpy_(const Tensor &alpha, const Tensor &x, Tensor y);
1612

1713
} // namespace infinicore::op
include/infinicore/ops/blas_dot.hpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
11
#pragma once
22

33
#include "../device.hpp"
4+
#include "../graph/graph.hpp"
45
#include "common/op.hpp"
56

67
namespace infinicore::op {
78

8-
class BlasDot {
9-
public:
10-
using schema = void (*)(Tensor, Tensor, Tensor);
11-
static void execute(Tensor result, Tensor x, Tensor y);
12-
static common::OpDispatcher<schema> &dispatcher();
13-
};
9+
INFINICORE_GRAPH_OP_CLASS(BlasDot, const Tensor &, const Tensor &, Tensor);
1410

15-
Tensor blas_dot(Tensor x, Tensor y);
16-
void blas_dot_(Tensor result, Tensor x, Tensor y);
11+
Tensor blas_dot(const Tensor &x, const Tensor &y);
12+
void blas_dot_(const Tensor &x, const Tensor &y, Tensor result);
1713

1814
} // namespace infinicore::op

include/infinicore/ops/nrm2.hpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
11
#pragma once
22

33
#include "../device.hpp"
4+
#include "../graph/graph.hpp"
45
#include "common/op.hpp"
56

67
namespace infinicore::op {
78

8-
class Nrm2 {
9-
public:
10-
using schema = void (*)(Tensor, Tensor);
11-
static void execute(Tensor result, Tensor x);
12-
static common::OpDispatcher<schema> &dispatcher();
13-
};
9+
INFINICORE_GRAPH_OP_CLASS(Nrm2, const Tensor &, Tensor);
1410

15-
Tensor nrm2(Tensor x);
16-
void nrm2_(Tensor result, Tensor x);
11+
Tensor nrm2(const Tensor &x);
12+
void nrm2_(const Tensor &x, Tensor result);
1713

1814
} // namespace infinicore::op

include/infinicore/ops/rot.hpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
11
#pragma once
22

33
#include "../device.hpp"
4+
#include "../graph/graph.hpp"
45
#include "common/op.hpp"
56

67
namespace infinicore::op {
78

8-
class Rot {
9-
public:
10-
using schema = void (*)(Tensor, Tensor, Tensor, Tensor);
11-
static void execute(Tensor x, Tensor y, Tensor c, Tensor s);
12-
static common::OpDispatcher<schema> &dispatcher();
13-
};
9+
INFINICORE_GRAPH_OP_CLASS(Rot, Tensor, Tensor, const Tensor &, const Tensor &);
1410

15-
void rot_(Tensor x, Tensor y, Tensor c, Tensor s);
11+
void rot_(Tensor x, Tensor y, const Tensor &c, const Tensor &s);
1612

1713
} // namespace infinicore::op

include/infinicore/ops/scal.hpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
11
#pragma once
22

33
#include "../device.hpp"
4+
#include "../graph/graph.hpp"
45
#include "common/op.hpp"
56

67
namespace infinicore::op {
78

8-
class Scal {
9-
public:
10-
using schema = void (*)(Tensor, Tensor);
11-
static void execute(Tensor alpha, Tensor x);
12-
static common::OpDispatcher<schema> &dispatcher();
13-
};
9+
INFINICORE_GRAPH_OP_CLASS(Scal, const Tensor &, Tensor);
1410

15-
void scal_(Tensor x, Tensor alpha);
11+
void scal_(const Tensor &alpha, Tensor x);
1612

1713
} // namespace infinicore::op

python/infinicore/ops/blas_dot.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ def blas_dot(x: Tensor, y: Tensor, *, out=None):
66
if out is None:
77
return Tensor(_infinicore.blas_dot(x._underlying, y._underlying))
88

9-
_infinicore.blas_dot_(out._underlying, x._underlying, y._underlying)
9+
_infinicore.blas_dot_(x._underlying, y._underlying, out._underlying)
10+
1011
return out

python/infinicore/ops/nrm2.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ def nrm2(x: Tensor, *, out=None):
66
if out is None:
77
return Tensor(_infinicore.nrm2(x._underlying))
88

9-
_infinicore.nrm2_(out._underlying, x._underlying)
9+
_infinicore.nrm2_(x._underlying, out._underlying)
10+
1011
return out

python/infinicore/ops/scal.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
33

44

55
def scal(x: Tensor, alpha: Tensor):
6-
_infinicore.scal_(x._underlying, alpha._underlying)
6+
_infinicore.scal_(alpha._underlying, x._underlying)
7+
78
return x

src/infinicore/ops/axpy/axpy.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,18 @@
44

55
namespace infinicore::op {
66

7-
common::OpDispatcher<Axpy::schema> &Axpy::dispatcher() {
8-
static common::OpDispatcher<Axpy::schema> dispatcher_;
9-
return dispatcher_;
10-
};
7+
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Axpy);
118

12-
void Axpy::execute(Tensor alpha, Tensor x, Tensor y) {
9+
Axpy::Axpy(const Tensor &alpha, const Tensor &x, Tensor y) {
1310
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(alpha, x, y);
14-
infinicore::context::setDevice(y->device());
15-
dispatcher().lookup(y->device().getType())(alpha, x, y);
11+
INFINICORE_GRAPH_OP_DISPATCH(y->device().getType(), alpha, x, y);
1612
}
1713

18-
void axpy_(Tensor alpha, Tensor x, Tensor y) {
14+
void Axpy::execute(const Tensor &alpha, const Tensor &x, Tensor y) {
15+
INFINICORE_GRAPH_OP_RECORD_OR_RUN(Axpy, alpha, x, y);
16+
}
17+
18+
void axpy_(const Tensor &alpha, const Tensor &x, Tensor y) {
1919
Axpy::execute(alpha, x, y);
2020
}
2121

src/infinicore/ops/axpy/ (infiniop backend implementation; file name missing from this capture)

Lines changed: 36 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,52 @@
1-
#include "../../utils.hpp"
2-
#include "infinicore/common/hash.hpp"
31
#include "infinicore/ops/axpy.hpp"
4-
#include "infinicore/ops/common/cache.hpp"
5-
#include <infiniop.h>
2+
3+
#include "../infiniop_impl.hpp"
64

75
namespace infinicore::op::axpy_impl::infiniop {
86

9-
thread_local common::OpCache<size_t, infiniopAxpyDescriptor_t> caches(
10-
100, // capacity
11-
[](infiniopAxpyDescriptor_t &desc) {
12-
if (desc != nullptr) {
13-
INFINICORE_CHECK_ERROR(infiniopDestroyAxpyDescriptor(desc));
14-
desc = nullptr;
15-
}
16-
});
7+
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Axpy, 100);
178

18-
void calculate(Tensor alpha, Tensor x, Tensor y) {
19-
size_t seed = hash_combine(alpha, x, y);
9+
struct PlannedMeta {
10+
std::shared_ptr<Descriptor> descriptor;
11+
graph::GraphTensor workspace, alpha, x, y;
12+
};
2013

21-
auto device_type = context::getDevice().getType();
22-
auto device_index = context::getDevice().getIndex();
14+
void *plan(const Tensor &alpha, const Tensor &x, Tensor y) {
15+
size_t seed = hash_combine(y, alpha, x);
2316

24-
auto &cache = caches.getCache(device_type, device_index);
17+
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
18+
Descriptor, descriptor, Axpy,
19+
seed,
20+
alpha->desc(), x->desc(), y->desc());
2521

26-
auto desc_opt = cache.get(seed);
27-
infiniopAxpyDescriptor_t desc = nullptr;
22+
INFINIOP_WORKSPACE_TENSOR(workspace, Axpy, descriptor);
2823

29-
if (!desc_opt) {
30-
INFINICORE_CHECK_ERROR(infiniopCreateAxpyDescriptor(
31-
context::getInfiniopHandle(y->device()), &desc,
32-
alpha->desc(), x->desc(), y->desc()));
33-
cache.put(seed, desc);
34-
} else {
35-
desc = *desc_opt;
36-
}
24+
return new PlannedMeta{
25+
descriptor,
26+
graph::GraphTensor(workspace),
27+
graph::GraphTensor(alpha),
28+
graph::GraphTensor(x),
29+
graph::GraphTensor(y)};
30+
}
3731

38-
size_t workspace_size = 0;
39-
INFINICORE_CHECK_ERROR(infiniopGetAxpyWorkspaceSize(desc, &workspace_size));
40-
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
32+
void run(void *planned_meta) {
33+
auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
4134

4235
INFINICORE_CHECK_ERROR(infiniopAxpy(
43-
desc, workspace->data(), workspace_size,
44-
alpha->data(), x->data(), y->data(), context::getStream()));
36+
planned->descriptor->desc,
37+
planned->workspace->data(),
38+
planned->workspace->numel(),
39+
planned->alpha->data(),
40+
planned->x->data(),
41+
planned->y->data(),
42+
context::getStream()));
43+
}
44+
45+
void cleanup(void **planned_meta_ptr) {
46+
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
47+
*planned_meta_ptr = nullptr;
4548
}
4649

47-
static bool registered = []() {
48-
Axpy::dispatcher().registerDevice({Device::Type::CPU,
49-
Device::Type::CAMBRICON,
50-
Device::Type::METAX},
51-
&calculate,
52-
false);
53-
return true;
54-
}();
50+
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Axpy, &plan, &run, &cleanup);
5551

5652
} // namespace infinicore::op::axpy_impl::infiniop

0 commit comments

Comments
 (0)