Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
39262e8
issue/497 - add dtype __eq__ and __hash__
wooway777 Oct 11, 2025
ab46327
issue/497 - simplified infinicore test functions
wooway777 Oct 11, 2025
93c74f0
issue/497 - improved test framework
wooway777 Oct 13, 2025
e8ff87c
issue/497 - add add interface to assist test
wooway777 Oct 13, 2025
212fe0d
issue/497 - generalized test framework based on add
wooway777 Oct 13, 2025
b0f83b4
issue/497 - support non-contiguous tensors in result comparison
wooway777 Oct 14, 2025
d4c0f08
issue/497 - temporarily fixed strided tensor creation
wooway777 Oct 14, 2025
bd69b50
issue/497 - rms norm interface
wooway777 Oct 14, 2025
2c0b996
issue/497 - now requires test function definition
wooway777 Oct 14, 2025
457a5a5
issue/497 - support mixed dtype
wooway777 Oct 15, 2025
0266fd5
issue/497 - initial rms norm test
wooway777 Oct 15, 2025
2ea0bad
issue/497 - unified in place and out of place tests
wooway777 Oct 15, 2025
33dfa4a
issue/497 - renamed src/infinicore/op
wooway777 Oct 15, 2025
2fa94f9
issue/497 - reduced comments
wooway777 Oct 15, 2025
7403618
issue/497 - attention
wooway777 Oct 16, 2025
95130a3
issue/497 - removed generic parameter mapping
wooway777 Oct 16, 2025
589b230
issue/497 - temporary attention test
wooway777 Oct 16, 2025
7199451
issue/497 - captitalize op name initial
wooway777 Oct 16, 2025
f844f76
issue/497 - add a script to run all op tests
wooway777 Oct 16, 2025
39aad83
issue/497 - fix comments
wooway777 Oct 16, 2025
623b3d5
issue/497 - simplified infinicore tensor creation from torch
wooway777 Oct 21, 2025
1b40ea0
issue/497 - support tensor init modes
wooway777 Oct 21, 2025
b02b749
issue/497 - support tensor from/to files
wooway777 Oct 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions include/infinicore/ops.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#pragma once

#include "op/matmul.hpp"
#include "op/ones.hpp"
#include "op/rearrange.hpp"
#include "ops/add.hpp"
#include "ops/attention.hpp"
#include "ops/matmul.hpp"
#include "ops/ones.hpp"
#include "ops/rearrange.hpp"
#include "ops/rms_norm.hpp"
17 changes: 17 additions & 0 deletions include/infinicore/ops/add.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class Add {
public:
using schema = void (*)(Tensor, Tensor, Tensor);
static void execute(Tensor c, Tensor a, Tensor b);
static common::OpDispatcher<schema> &dispatcher();
};

Tensor add(Tensor a, Tensor b);
void add_(Tensor c, Tensor a, Tensor b);
Tensor operator+(Tensor a, Tensor b);
} // namespace infinicore::op
16 changes: 16 additions & 0 deletions include/infinicore/ops/attention.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class Attention {
public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, size_t);
static void execute(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
static common::OpDispatcher<schema> &dispatcher();
};

Tensor attention(Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
void attention_(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
} // namespace infinicore::op
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@ namespace infinicore::op::common {
template <typename Fn>
class OpDispatcher {
public:
void registerDevice(Device::Type device_type, Fn fn, bool override_existing=true) {
if (table_[(size_t)device_type] == nullptr || override_existing){
void registerDevice(Device::Type device_type, Fn fn, bool override_existing = true) {
if (table_[(size_t)device_type] == nullptr || override_existing) {
table_[(size_t)device_type] = fn;
}
}

void registerDevice(std::initializer_list<Device::Type> device_types, Fn fn, bool override_existing=true) {
void registerDevice(std::initializer_list<Device::Type> device_types, Fn fn, bool override_existing = true) {
for (auto device_type : device_types) {
registerDevice(device_type, fn, override_existing);
}
}

void registerAll(Fn fn, bool override_existing=true) {
void registerAll(Fn fn, bool override_existing = true) {
for (size_t device_type = 0; device_type < static_cast<size_t>(Device::Type::COUNT); ++device_type) {
registerDevice((Device::Type)device_type, fn, override_existing);
}
Expand Down
File renamed without changes.
File renamed without changes.
16 changes: 16 additions & 0 deletions include/infinicore/ops/rms_norm.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class RMSNorm {
public:
using schema = void (*)(Tensor, Tensor, Tensor, float);
static void execute(Tensor y, Tensor x, Tensor weight, float epsilon = 1e-5f);
static common::OpDispatcher<schema> &dispatcher();
};

Tensor rms_norm(Tensor x, Tensor weight, float epsilon = 1e-5f);
void rms_norm_(Tensor y, Tensor x, Tensor weight, float epsilon = 1e-5f);
} // namespace infinicore::op
6 changes: 6 additions & 0 deletions python/infinicore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@
uint8,
)
from infinicore.ntops import use_ntops
from infinicore.ops.add import add
from infinicore.ops.attention import attention
from infinicore.ops.matmul import matmul
from infinicore.ops.rearrange import rearrange
from infinicore.ops.rms_norm import rms_norm
from infinicore.tensor import (
empty,
from_blob,
Expand Down Expand Up @@ -66,8 +69,11 @@
# `ntops` integration.
"use_ntops",
# Operations.
"add",
"attention",
"matmul",
"rearrange",
"rms_norm",
"empty",
"from_blob",
"ones",
Expand Down
25 changes: 23 additions & 2 deletions python/infinicore/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
class dtype:
def __init__(self, data_type):
"""An internal method. Please do not use this directly."""

self._underlying = data_type

def __repr__(self):
Expand All @@ -29,9 +28,31 @@ def __repr__(self):
_infinicore.DataType.C128: "complex128",
_infinicore.DataType.BF16: "bfloat16",
}

return f"infinicore.{repr_map[self._underlying]}"

def __eq__(self, other):
"""
Compare two dtype objects for equality.

Args:
other: The object to compare with

Returns:
bool: True if both objects are dtype instances with the same underlying data type
"""
if not isinstance(other, dtype):
return False
return self._underlying == other._underlying

def __hash__(self):
"""
Return a hash value for the dtype object.

Returns:
int: Hash value based on the underlying data type
"""
return hash(self._underlying)


float32 = dtype(_infinicore.DataType.F32)
float = float32
Expand Down
9 changes: 9 additions & 0 deletions python/infinicore/ops/add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def add(input, other, *, out=None):
if out is None:
return Tensor(_infinicore.add(input._underlying, other._underlying))

_infinicore.add_(out._underlying, input._underlying, other._underlying)
26 changes: 26 additions & 0 deletions python/infinicore/ops/attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def attention(q, k, v, k_cache, v_cache, pos, *, out=None):
if out is None:
return Tensor(
_infinicore.attention(
q._underlying,
k._underlying,
v._underlying,
k_cache._underlying,
v_cache._underlying,
pos,
)
)

_infinicore.attention_(
out._underlying,
q._underlying,
k._underlying,
v._underlying,
k_cache._underlying,
v_cache._underlying,
pos,
)
13 changes: 13 additions & 0 deletions python/infinicore/ops/rms_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def rms_norm(input, weight, epsilon=1e-5, *, out=None):
if out is None:
return Tensor(
_infinicore.rms_norm(input._underlying, weight._underlying, epsilon)
)

_infinicore.rms_norm_(
out._underlying, input._underlying, weight._underlying, epsilon
)
24 changes: 24 additions & 0 deletions src/infinicore/ops/add/add.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include "infinicore/ops/add.hpp"

namespace infinicore::op {

common::OpDispatcher<Add::schema> &Add::dispatcher() {
static common::OpDispatcher<Add::schema> dispatcher_;
return dispatcher_;
};

void Add::execute(Tensor c, Tensor a, Tensor b) {
dispatcher().lookup(context::getDevice().getType())(c, a, b);
}

Tensor add(Tensor a, Tensor b) {
auto c = Tensor::empty(a->shape(), a->dtype(), a->device());
add_(c, a, b);
return c;
}

void add_(Tensor c, Tensor a, Tensor b) {
Add::execute(c, a, b);
}

} // namespace infinicore::op
52 changes: 52 additions & 0 deletions src/infinicore/ops/add/add_infiniop.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/add.hpp"
#include "infinicore/ops/common/cache.hpp"
#include <infiniop.h>

namespace infinicore::op::add_impl::infiniop {

thread_local common::OpCache<size_t, infiniopAddDescriptor_t> caches(
100, // capacity
[](infiniopAddDescriptor_t &desc) {
if (desc != nullptr) {
INFINICORE_CHECK_ERROR(infiniopDestroyAddDescriptor(desc));
desc = nullptr;
}
});

void calculate(Tensor c, Tensor a, Tensor b) {
size_t seed = hash_combine(c, b, a);

auto device_type = context::getDevice().getType();
auto device_index = context::getDevice().getIndex();

auto &cache = caches.getCache(device_type, device_index);

auto desc_opt = cache.get(seed);
infiniopAddDescriptor_t desc = nullptr;

if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateAddDescriptor(
context::getInfiniopHandle(), &desc,
c->desc(), a->desc(), b->desc()));
cache.put(seed, desc);
} else {
desc = *desc_opt;
}

size_t workspace_size = 0;
INFINICORE_CHECK_ERROR(infiniopGetAddWorkspaceSize(desc, &workspace_size));
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);

INFINICORE_CHECK_ERROR(infiniopAdd(
desc, workspace->data(), workspace_size,
c->data(), a->data(), b->data(), context::getStream()));
}

static bool registered = []() {
Add::dispatcher().registerAll(&calculate, false);
return true;
}();

} // namespace infinicore::op::add_impl::infiniop
28 changes: 28 additions & 0 deletions src/infinicore/ops/attention/attention.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#include "infinicore/ops/attention.hpp"

namespace infinicore::op {

common::OpDispatcher<Attention::schema> &Attention::dispatcher() {
static common::OpDispatcher<Attention::schema> dispatcher_;
return dispatcher_;
};

void Attention::execute(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos) {
dispatcher().lookup(context::getDevice().getType())(out, q, k, v, k_cache, v_cache, pos);
}

Tensor attention(Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos) {
size_t n_q_head = q->shape()[0];
size_t seq_len = q->shape()[1];
size_t head_dim = q->shape()[2];
Shape shape = {seq_len, n_q_head, head_dim};
auto out = Tensor::empty(shape, q->dtype(), q->device());
attention_(out, q, k, v, k_cache, v_cache, pos);
return out;
}

void attention_(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos) {
Attention::execute(out, q, k, v, k_cache, v_cache, pos);
}

} // namespace infinicore::op
54 changes: 54 additions & 0 deletions src/infinicore/ops/attention/attention_infiniop.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/attention.hpp"
#include "infinicore/ops/common/cache.hpp"
#include <infiniop.h>

namespace infinicore::op::attention_impl::infiniop {

thread_local common::OpCache<size_t, infiniopAttentionDescriptor_t> caches(
100, // capacity
[](infiniopAttentionDescriptor_t &desc) {
if (desc != nullptr) {
INFINICORE_CHECK_ERROR(infiniopDestroyAttentionDescriptor(desc));
desc = nullptr;
}
});

void calculate(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos) {
size_t seed = hash_combine(out, q, k, v, k_cache, v_cache, pos);

auto device_type = context::getDevice().getType();
auto device_index = context::getDevice().getIndex();

auto &cache = caches.getCache(device_type, device_index);

auto desc_opt = cache.get(seed);
infiniopAttentionDescriptor_t desc = nullptr;

if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateAttentionDescriptor(
context::getInfiniopHandle(), &desc,
out->desc(), q->desc(), k->desc(), v->desc(),
k_cache->desc(), v_cache->desc(), pos));
cache.put(seed, desc);
} else {
desc = *desc_opt;
}

size_t workspace_size = 0;
INFINICORE_CHECK_ERROR(infiniopGetAttentionWorkspaceSize(desc, &workspace_size));
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);

INFINICORE_CHECK_ERROR(infiniopAttention(
desc, workspace->data(), workspace_size,
out->data(), q->data(), k->data(), v->data(),
k_cache->data(), v_cache->data(), context::getStream()));
}

static bool registered = []() {
Attention::dispatcher().registerAll(&calculate, false);
return true;
}();

} // namespace infinicore::op::attention_impl::infiniop
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "infinicore/op/matmul.hpp"
#include "infinicore/ops/matmul.hpp"

namespace infinicore::op {

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/op/common/cache.hpp"
#include "infinicore/op/matmul.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/matmul.hpp"
#include <infiniop.h>

namespace infinicore::op::matmul_impl::infiniop {
Expand All @@ -27,7 +27,9 @@ void calculate(Tensor c, Tensor a, Tensor b) {
infiniopGemmDescriptor_t desc = nullptr;

if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateGemmDescriptor(context::getInfiniopHandle(), &desc, c->desc(), a->desc(), b->desc()));
INFINICORE_CHECK_ERROR(infiniopCreateGemmDescriptor(
context::getInfiniopHandle(), &desc,
c->desc(), a->desc(), b->desc()));
cache.put(seed, desc);
} else {
desc = *desc_opt;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "infinicore/op/ones.hpp"
#include "infinicore/ops/ones.hpp"

namespace infinicore::op {

Expand Down
Loading