Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/infinicore/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "ops/add.hpp"
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/matmul.hpp"
#include "ops/ones.hpp"
#include "ops/rearrange.hpp"
Expand Down
16 changes: 16 additions & 0 deletions include/infinicore/ops/causal_softmax.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {
// Operator entry point for causal softmax. Every member is static; the class
// only groups the kernel signature, the dispatch entry and the dispatcher accessor.
class CausalSoftmax {
public:
// Function-pointer signature a backend kernel must have. execute() calls the
// selected kernel with (output, input), in that order.
using schema = void (*)(Tensor, Tensor);
// Runs the operator, handing `output` and `input` to the kernel chosen by the dispatcher.
static void execute(Tensor output, Tensor input);
// Accessor for the dispatcher that stores `schema` function pointers.
static common::OpDispatcher<schema> &dispatcher();
};

// Returns a tensor holding the causal softmax of `input`.
Tensor causal_softmax(Tensor input);
// Variant in which the caller supplies the destination tensor `output`; returns nothing.
void causal_softmax_(Tensor output, Tensor input);
} // namespace infinicore::op
2 changes: 2 additions & 0 deletions python/infinicore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from infinicore.ntops import use_ntops
from infinicore.ops.add import add
from infinicore.ops.attention import attention
from infinicore.ops.causal_softmax import causal_softmax
from infinicore.ops.matmul import matmul
from infinicore.ops.rearrange import rearrange
from infinicore.ops.rms_norm import rms_norm
Expand Down Expand Up @@ -71,6 +72,7 @@
# Operations.
"add",
"attention",
"causal_softmax",
"matmul",
"rearrange",
"rms_norm",
Expand Down
9 changes: 9 additions & 0 deletions python/infinicore/ops/causal_softmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def causal_softmax(input, *, out=None):
    """Apply causal softmax to ``input``.

    Args:
        input: Tensor wrapper whose ``_underlying`` handle is passed to the
            native ``_infinicore`` op.
        out: Optional destination tensor (keyword-only). When given, the
            result is written into it through ``_infinicore.causal_softmax_``.

    Returns:
        A new ``Tensor`` wrapping the native result when ``out`` is ``None``;
        otherwise ``out`` itself.
    """
    if out is None:
        return Tensor(_infinicore.causal_softmax(input._underlying))

    _infinicore.causal_softmax_(out._underlying, input._underlying)
    # Hand the destination back so both call forms return a tensor instead of
    # the ``out=`` path silently returning None.
    return out
32 changes: 32 additions & 0 deletions src/infinicore/ops/causal_softmax/causal_softmax.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#include "infinicore/ops/causal_softmax.hpp"
#include <stdexcept>

namespace infinicore::op {

/// Returns the dispatcher that maps device types to CausalSoftmax kernels.
///
/// The dispatcher is a function-local static, so it is constructed the first
/// time this accessor is called. Callers that run during static initialization
/// therefore always receive a fully constructed object.
common::OpDispatcher<CausalSoftmax::schema> &CausalSoftmax::dispatcher() {
    static common::OpDispatcher<CausalSoftmax::schema> dispatcher_;
    return dispatcher_;
}

/// Looks up the kernel registered for the current context device and invokes
/// it with (output, input).
///
/// @throws std::runtime_error when the dispatcher has no kernel for that device type.
void CausalSoftmax::execute(Tensor output, Tensor input) {
    const auto current_device_type = context::getDevice().getType();
    auto kernel = dispatcher().lookup(current_device_type);

    // Happy path first: a registered kernel simply runs.
    if (kernel != nullptr) {
        kernel(output, input);
        return;
    }

    throw std::runtime_error("No CausalSoftmax implementation found for device type: " + std::to_string(static_cast<int>(current_device_type)));
}

/// Creates a tensor with the same shape, dtype and device as `input`, fills it
/// through causal_softmax_, and returns it.
Tensor causal_softmax(Tensor input) {
    auto result = Tensor::empty(input->shape(), input->dtype(), input->device());
    causal_softmax_(result, input);
    return result;
}

// Caller-supplied-output variant: forwards both tensors unchanged to
// CausalSoftmax::execute, which selects and runs the kernel.
void causal_softmax_(Tensor output, Tensor input) {
CausalSoftmax::execute(output, input);
}
} // namespace infinicore::op
52 changes: 52 additions & 0 deletions src/infinicore/ops/causal_softmax/causal_softmax_infiniop.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/causal_softmax.hpp"
#include "infinicore/ops/common/cache.hpp"
#include <infiniop.h>

namespace infinicore::op::causal_softmax_impl::infiniop {

// Per-thread cache of infiniop causal-softmax descriptors, keyed by a size_t hash.
// The callback destroys a descriptor handle and nulls it; null handles are left alone.
thread_local common::OpCache<size_t, infiniopCausalSoftmaxDescriptor_t> caches(
    100, // capacity
    [](infiniopCausalSoftmaxDescriptor_t &desc) {
        if (desc == nullptr) {
            return;
        }
        INFINICORE_CHECK_ERROR(infiniopDestroyCausalSoftmaxDescriptor(desc));
        desc = nullptr;
    });

/// infiniop-backed causal softmax kernel.
///
/// Reuses a cached descriptor for this (output, input) hash on the current
/// device when one exists, otherwise creates and caches one. It then queries
/// the workspace size, allocates the workspace, and launches the op on the
/// current stream.
void calculate(Tensor output, Tensor input) {
    size_t cache_key = hash_combine(output, input);

    // Caches are selected per (device type, device index).
    auto &&current_device = context::getDevice();
    auto &device_cache = caches.getCache(current_device.getType(), current_device.getIndex());

    infiniopCausalSoftmaxDescriptor_t descriptor = nullptr;
    if (auto cached = device_cache.get(cache_key)) {
        descriptor = *cached;
    } else {
        // Cache miss: build the descriptor from both tensor descriptors and remember it.
        INFINICORE_CHECK_ERROR(infiniopCreateCausalSoftmaxDescriptor(
            context::getInfiniopHandle(), &descriptor,
            output->desc(), input->desc()));
        device_cache.put(cache_key, descriptor);
    }

    size_t workspace_bytes = 0;
    INFINICORE_CHECK_ERROR(infiniopGetCausalSoftmaxWorkspaceSize(descriptor, &workspace_bytes));
    std::shared_ptr<Memory> scratch = context::allocateMemory(workspace_bytes);

    INFINICORE_CHECK_ERROR(infiniopCausalSoftmax(
        descriptor, scratch->data(), workspace_bytes,
        output->data(), input->data(), context::getStream()));
}

// Self-registration: the lambda runs when this static is initialized and
// registers `calculate` with the CausalSoftmax dispatcher via registerAll.
// The bool itself carries no information; it exists so the lambda has
// something to initialize.
// NOTE(review): the second argument (`false`) is presumably an
// "override existing registration" flag — confirm against OpDispatcher::registerAll.
static bool registered = []() {
CausalSoftmax::dispatcher().registerAll(&calculate, false);
return true;
}();

} // namespace infinicore::op::causal_softmax_impl::infiniop
2 changes: 2 additions & 0 deletions src/infinicore/pybind11/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "ops/add.hpp"
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/matmul.hpp"
#include "ops/rearrange.hpp"
#include "ops/rms_norm.hpp"
Expand All @@ -15,6 +16,7 @@ namespace infinicore::ops {
inline void bind(py::module &m) {
bind_add(m);
bind_attention(m);
bind_causal_softmax(m);
bind_matmul(m);
bind_rearrange(m);
bind_rms_norm(m);
Expand Down
24 changes: 24 additions & 0 deletions src/infinicore/pybind11/ops/causal_softmax.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#pragma once

#include <pybind11/pybind11.h>

#include "infinicore/ops/causal_softmax.hpp"

namespace py = pybind11;

namespace infinicore::ops {

// Registers the two causal-softmax entry points on the given pybind11 module.
inline void bind_causal_softmax(py::module &m) {
// Value-returning form: exposes op::causal_softmax with keyword argument `input`.
m.def("causal_softmax",
&op::causal_softmax,
py::arg("input"),
R"doc(Causal softmax activation function.)doc");

// Caller-supplied-output form: exposes op::causal_softmax_ with keyword arguments `output`, `input`.
// NOTE(review): the docstring says "In-place", but the bound function takes a
// separate `output` argument rather than overwriting `input` — confirm the wording is intended.
m.def("causal_softmax_",
&op::causal_softmax_,
py::arg("output"),
py::arg("input"),
R"doc(In-place causal softmax activation function.)doc");
}

} // namespace infinicore::ops
114 changes: 114 additions & 0 deletions test/infinicore/ops/causal_softmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner

# ==============================================================================
# Operator-specific configuration
# ==============================================================================

# Each entry is a tuple: (operation_mode, shape, input_strides, output_strides).
# A stride entry of None makes parse_test_cases build the spec with
# TensorSpec.from_tensor instead of TensorSpec.from_strided_tensor.
# Causal softmax takes a single input tensor; the output spec uses the same shape.
# NOTE(review): TestCase.BOTH presumably exercises both the value-returning and
# the out= call forms — confirm against framework.base.
_TEST_CASES_DATA = [
# Basic 2D causal softmax (square and wide)
(TestCase.BOTH, (3, 3), None, None),
(TestCase.BOTH, (32, 512), None, None),
# Strided tensors: explicit strides on both input and output
(TestCase.BOTH, (32, 512), (1024, 1), (1024, 1)),
# 3D causal softmax, including one case with explicit input strides only
(TestCase.BOTH, (32, 5, 5), None, None),
(TestCase.BOTH, (32, 20, 512), None, None),
(TestCase.BOTH, (32, 20, 512), (20480, 512, 1), None),
(TestCase.BOTH, (28, 15, 15), None, None),
]


def parse_test_cases(data):
    """Turn one raw tuple into a ``TestCase``.

    ``data`` is laid out as ``(operation_mode, shape, input_strides,
    output_strides)``. The two stride entries may be omitted or ``None``, in
    which case the matching spec is built without explicit strides.
    """

    def _spec_for(strides):
        # One helper serves both the input and the output: same shape, optional strides.
        if strides is None:
            return TensorSpec.from_tensor(shape)
        return TensorSpec.from_strided_tensor(shape, strides)

    operation_mode, shape = data[0], data[1]
    input_strides = data[2] if len(data) > 2 else None
    output_strides = data[3] if len(data) > 3 else None

    # Single tensor input, then the output spec.
    return TestCase(operation_mode, [_spec_for(input_strides)], _spec_for(output_strides))


# Materialize TestCase objects once at import time from the raw tuples.
_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA]

# Dtypes returned by OpTest.get_tensor_dtypes.
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]

# Per-dtype comparison tolerances returned by OpTest.get_tolerance_map.
# NOTE(review): "atol"/"rtol" are presumably absolute/relative tolerances
# consumed by the test framework — confirm against framework.base.
_TOLERANCE_MAP = {
infinicore.float16: {"atol": 1e-3, "rtol": 1e-2},
infinicore.float32: {"atol": 3e-5, "rtol": 1e-5},
infinicore.bfloat16: {"atol": 5e-3, "rtol": 5e-2},
}


class OpTest(BaseOperatorTest):
    """Operator test comparing ``infinicore.causal_softmax`` with a PyTorch reference."""

    def __init__(self):
        super().__init__("CausalSoftmax")

    def get_test_cases(self):
        """Return the module-level list of parsed test cases."""
        return _TEST_CASES

    def get_tensor_dtypes(self):
        """Return the dtypes this operator is tested with."""
        return _TENSOR_DTYPES

    def get_tolerance_map(self):
        """Return the per-dtype tolerance settings."""
        return _TOLERANCE_MAP

    def torch_operator(self, input, out=None, **kwargs):
        """PyTorch reference: mask out positions above the causal boundary, then softmax.

        The mask is the strictly-lower triangle of a ones tensor flipped along
        the last two dims; masked entries are set to -inf before the softmax
        over the last dim. When ``out`` is given the result is copied into it
        and ``out`` is returned.
        """
        source_dtype = input.dtype

        # 1 marks an entry that must not contribute to the softmax.
        blocked = torch.tril(torch.ones_like(input), diagonal=-1).flip(dims=[-2, -1])
        masked_logits = torch.where(blocked == 1, -torch.inf, input.to(torch.float32))

        probabilities = torch.nn.functional.softmax(
            masked_logits, dim=-1, dtype=source_dtype
        )

        if out is None:
            return probabilities
        out.copy_(probabilities)
        return out

    def infinicore_operator(self, input, out=None, **kwargs):
        """Call the infinicore implementation under test."""
        return infinicore.causal_softmax(input, out=out)


def main():
    """Script entry point: build a GenericTestRunner for OpTest and call run_and_exit()."""
    GenericTestRunner(OpTest).run_and_exit()


if __name__ == "__main__":
    main()
2 changes: 1 addition & 1 deletion xmake/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ target("infinicore-test")
add_files(os.projectdir().."/src/infinicore/context/*.cc")
add_files(os.projectdir().."/src/infinicore/context/*/*.cc")
add_files(os.projectdir().."/src/infinicore/tensor/*.cc")
add_files(os.projectdir().."/src/infinicore/op/*/*.cc")
add_files(os.projectdir().."/src/infinicore/ops/*/*.cc")

add_files(os.projectdir().."/src/infinicore-test/*.cc")

Expand Down