Skip to content

Commit 7b79837

Browse files
committed
issue/497 - add add interface to assist test
1 parent ed56bd7 commit 7b79837

22 files changed

Lines changed: 162 additions & 27 deletions

File tree

include/infinicore/ops.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3-
#include "op/matmul.hpp"
4-
#include "op/ones.hpp"
5-
#include "op/rearrange.hpp"
3+
#include "ops/add.hpp"
4+
#include "ops/matmul.hpp"
5+
#include "ops/ones.hpp"
6+
#include "ops/rearrange.hpp"

include/infinicore/ops/add.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#pragma once
2+
3+
#include "../device.hpp"
4+
#include "common/op.hpp"
5+
6+
namespace infinicore::op {
7+
class Add {
8+
public:
9+
using schema = void (*)(Tensor, Tensor, Tensor);
10+
static void execute(Tensor c, Tensor a, Tensor b);
11+
static common::OpDispatcher<schema> &dispatcher();
12+
};
13+
14+
Tensor add(Tensor a, Tensor b);
15+
void add_(Tensor c, Tensor a, Tensor b);
16+
Tensor operator+(Tensor a, Tensor b);
17+
} // namespace infinicore::op

include/infinicore/op/common/dispatcher.hpp renamed to include/infinicore/ops/common/dispatcher.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,19 @@ namespace infinicore::op::common {
88
template <typename Fn>
99
class OpDispatcher {
1010
public:
11-
void registerDevice(Device::Type device_type, Fn fn, bool override_existing=true) {
12-
if (table_[(size_t)device_type] == nullptr || override_existing){
11+
void registerDevice(Device::Type device_type, Fn fn, bool override_existing = true) {
12+
if (table_[(size_t)device_type] == nullptr || override_existing) {
1313
table_[(size_t)device_type] = fn;
1414
}
1515
}
1616

17-
void registerDevice(std::initializer_list<Device::Type> device_types, Fn fn, bool override_existing=true) {
17+
void registerDevice(std::initializer_list<Device::Type> device_types, Fn fn, bool override_existing = true) {
1818
for (auto device_type : device_types) {
1919
registerDevice(device_type, fn, override_existing);
2020
}
2121
}
2222

23-
void registerAll(Fn fn, bool override_existing=true) {
23+
void registerAll(Fn fn, bool override_existing = true) {
2424
for (size_t device_type = 0; device_type < static_cast<size_t>(Device::Type::COUNT); ++device_type) {
2525
registerDevice((Device::Type)device_type, fn, override_existing);
2626
}

python/infinicore/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
short,
2525
uint8,
2626
)
27+
from infinicore.ops.add import add
2728
from infinicore.ops.matmul import matmul
2829
from infinicore.ops.rearrange import rearrange
2930
from infinicore.tensor import (
@@ -63,6 +64,7 @@
6364
"short",
6465
"uint8",
6566
# Operations.
67+
"add",
6668
"matmul",
6769
"rearrange",
6870
"empty",

python/infinicore/ops/add.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from infinicore.lib import _infinicore
2+
from infinicore.tensor import Tensor
3+
4+
5+
def add(input, other, *, out=None):
6+
if out is None:
7+
return Tensor(_infinicore.add(input._underlying, other._underlying))
8+
9+
_infinicore.add_(out._underlying, input._underlying, other._underlying)

0 commit comments

Comments
 (0)