Skip to content

Commit 74564bd

Browse files
committed
Add InfiniCore rot, rotg, rotm and rotmg wrappers
1 parent 8b57466 commit 74564bd

27 files changed

Lines changed: 1113 additions & 0 deletions

File tree

include/infinicore/ops.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@
4848
#include "ops/relu.hpp"
4949
#include "ops/rms_norm.hpp"
5050
#include "ops/rope.hpp"
51+
#include "ops/rot.hpp"
52+
#include "ops/rotg.hpp"
53+
#include "ops/rotm.hpp"
54+
#include "ops/rotmg.hpp"
5155
#include "ops/scal.hpp"
5256
#include "ops/silu.hpp"
5357
#include "ops/silu_and_mul.hpp"

include/infinicore/ops/rot.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#pragma once
2+
3+
#include "../device.hpp"
4+
#include "common/op.hpp"
5+
6+
namespace infinicore::op {
7+
8+
class Rot {
9+
public:
10+
using schema = void (*)(Tensor, Tensor, Tensor, Tensor);
11+
static void execute(Tensor x, Tensor y, Tensor c, Tensor s);
12+
static common::OpDispatcher<schema> &dispatcher();
13+
};
14+
15+
void rot_(Tensor x, Tensor y, Tensor c, Tensor s);
16+
17+
} // namespace infinicore::op

include/infinicore/ops/rotg.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#pragma once
2+
3+
#include "../device.hpp"
4+
#include "common/op.hpp"
5+
6+
namespace infinicore::op {
7+
8+
class Rotg {
9+
public:
10+
using schema = void (*)(Tensor, Tensor, Tensor, Tensor);
11+
static void execute(Tensor x, Tensor y, Tensor c, Tensor s);
12+
static common::OpDispatcher<schema> &dispatcher();
13+
};
14+
15+
void rotg_(Tensor x, Tensor y, Tensor c, Tensor s);
16+
17+
} // namespace infinicore::op

include/infinicore/ops/rotm.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#pragma once
2+
3+
#include "../device.hpp"
4+
#include "common/op.hpp"
5+
6+
namespace infinicore::op {
7+
8+
class Rotm {
9+
public:
10+
using schema = void (*)(Tensor, Tensor, Tensor);
11+
static void execute(Tensor x, Tensor y, Tensor param);
12+
static common::OpDispatcher<schema> &dispatcher();
13+
};
14+
15+
void rotm_(Tensor x, Tensor y, Tensor param);
16+
17+
} // namespace infinicore::op

include/infinicore/ops/rotmg.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#pragma once
2+
3+
#include "../device.hpp"
4+
#include "common/op.hpp"
5+
6+
namespace infinicore::op {
7+
8+
class Rotmg {
9+
public:
10+
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor);
11+
static void execute(Tensor d1, Tensor d2, Tensor x1, Tensor y1, Tensor param);
12+
static common::OpDispatcher<schema> &dispatcher();
13+
};
14+
15+
void rotmg_(Tensor d1, Tensor d2, Tensor x1, Tensor y1, Tensor param);
16+
17+
} // namespace infinicore::op

python/infinicore/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@
114114
from infinicore.ops.paged_caching import paged_caching
115115
from infinicore.ops.rearrange import rearrange
116116
from infinicore.ops.reciprocal import reciprocal
117+
from infinicore.ops.rot import rot
118+
from infinicore.ops.rotg import rotg
119+
from infinicore.ops.rotm import rotm
120+
from infinicore.ops.rotmg import rotmg
117121
from infinicore.ops.scal import scal
118122
from infinicore.ops.scatter import scatter
119123
from infinicore.ops.sinh import sinh
@@ -247,6 +251,10 @@
247251
"float_power",
248252
"flipud",
249253
"scatter",
254+
"rot",
255+
"rotg",
256+
"rotm",
257+
"rotmg",
250258
"scal",
251259
"logcumsumexp",
252260
"logical_not",

python/infinicore/ops/rot.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from infinicore.lib import _infinicore
2+
from infinicore.tensor import Tensor
3+
4+
5+
def rot(x: Tensor, y: Tensor, c: Tensor, s: Tensor):
6+
_infinicore.rot_(x._underlying, y._underlying, c._underlying, s._underlying)
7+
return x, y

python/infinicore/ops/rotg.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from infinicore.lib import _infinicore
2+
from infinicore.tensor import Tensor
3+
4+
5+
def rotg(x: Tensor, y: Tensor, c: Tensor, s: Tensor):
6+
_infinicore.rotg_(x._underlying, y._underlying, c._underlying, s._underlying)
7+
return x, y, c, s

python/infinicore/ops/rotm.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from infinicore.lib import _infinicore
2+
from infinicore.tensor import Tensor
3+
4+
5+
def rotm(x: Tensor, y: Tensor, param: Tensor):
6+
_infinicore.rotm_(x._underlying, y._underlying, param._underlying)
7+
return x, y

python/infinicore/ops/rotmg.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from infinicore.lib import _infinicore
2+
from infinicore.tensor import Tensor
3+
4+
5+
def rotmg(d1: Tensor, d2: Tensor, x1: Tensor, y1: Tensor, param: Tensor):
6+
_infinicore.rotmg_(
7+
d1._underlying,
8+
d2._underlying,
9+
x1._underlying,
10+
y1._underlying,
11+
param._underlying,
12+
)
13+
return d1, d2, x1, param

0 commit comments

Comments
 (0)