Skip to content

Commit 89b42a8

Browse files
authored
Merge pull request #533 from gongchensu/feature/add_causalSoftmax_python_api
Add causalSoftmax operator Python interface and tests.
2 parents 4ee9109 + ecb938c commit 89b42a8

9 files changed

Lines changed: 252 additions & 0 deletions

File tree

include/infinicore/ops.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "ops/add.hpp"
44
#include "ops/attention.hpp"
5+
#include "ops/causal_softmax.hpp"
56
#include "ops/matmul.hpp"
67
#include "ops/ones.hpp"
78
#include "ops/rearrange.hpp"
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#pragma once
2+
3+
#include "../device.hpp"
4+
#include "common/op.hpp"
5+
6+
namespace infinicore::op {
7+
class CausalSoftmax {
8+
public:
9+
using schema = void (*)(Tensor, Tensor);
10+
static void execute(Tensor output, Tensor input);
11+
static common::OpDispatcher<schema> &dispatcher();
12+
};
13+
14+
/// Computes causal softmax of `input` into a newly created tensor and returns it.
Tensor causal_softmax(Tensor input);
15+
/// Computes causal softmax of `input`, writing the result into the caller-provided `output`.
void causal_softmax_(Tensor output, Tensor input);
16+
} // namespace infinicore::op

python/infinicore/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from infinicore.ntops import use_ntops
2828
from infinicore.ops.add import add
2929
from infinicore.ops.attention import attention
30+
from infinicore.ops.causal_softmax import causal_softmax
3031
from infinicore.ops.matmul import matmul
3132
from infinicore.ops.rearrange import rearrange
3233
from infinicore.ops.rms_norm import rms_norm
@@ -73,6 +74,7 @@
7374
# Operations.
7475
"add",
7576
"attention",
77+
"causal_softmax",
7678
"matmul",
7779
"rearrange",
7880
"rms_norm",
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from infinicore.lib import _infinicore
2+
from infinicore.tensor import Tensor
3+
4+
5+
def causal_softmax(input, *, out=None):
    """Apply the causal softmax operator to ``input``.

    Args:
        input: infinicore ``Tensor`` to run the operator on.
        out: Optional, keyword-only ``Tensor`` that receives the result. When
            omitted, the backend creates the result tensor.

    Returns:
        A new ``Tensor`` wrapping the backend result when ``out`` is ``None``;
        otherwise ``out`` itself, after the result has been written into it.
    """
    if out is None:
        return Tensor(_infinicore.causal_softmax(input._underlying))

    _infinicore.causal_softmax_(out._underlying, input._underlying)

    # Hand the destination tensor back so the ``out=`` form is usable in
    # expressions; this path previously fell through and returned ``None``.
    return out
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#include "infinicore/ops/causal_softmax.hpp"
2+
#include <stdexcept>
3+
4+
namespace infinicore::op {
5+
6+
/// Returns the single dispatcher instance used for CausalSoftmax kernels.
///
/// The dispatcher is a function-local static, so it is constructed on first
/// use and every call returns a reference to that same object.
common::OpDispatcher<CausalSoftmax::schema> &CausalSoftmax::dispatcher() {
    static common::OpDispatcher<CausalSoftmax::schema> dispatcher_;
    return dispatcher_;
}
10+
11+
/// Dispatches causal softmax to the kernel registered for the current device.
///
/// @param output Tensor passed to the kernel as its first argument.
/// @param input  Tensor passed to the kernel as its second argument.
/// @throws std::runtime_error if no kernel is registered for the device type.
void CausalSoftmax::execute(Tensor output, Tensor input) {
    const auto device_type = context::getDevice().getType();
    const auto kernel = dispatcher().lookup(device_type);

    if (kernel != nullptr) {
        kernel(output, input);
        return;
    }

    throw std::runtime_error("No CausalSoftmax implementation found for device type: " + std::to_string(static_cast<int>(device_type)));
}
21+
22+
/// Computes causal softmax of `input` into a freshly created tensor.
///
/// The result tensor is created with the same shape, dtype and device as
/// `input`, filled via `causal_softmax_`, and returned.
Tensor causal_softmax(Tensor input) {
    Shape result_shape = input->shape();
    auto result = Tensor::empty(result_shape, input->dtype(), input->device());

    causal_softmax_(result, input);
    return result;
}
28+
29+
/// Computes causal softmax of `input` into the caller-provided `output`.
///
/// Thin forwarding wrapper around `CausalSoftmax::execute`.
void causal_softmax_(Tensor output, Tensor input) {
    CausalSoftmax::execute(output, input);
}
32+
} // namespace infinicore::op
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#include "../../utils.hpp"
2+
#include "infinicore/common/hash.hpp"
3+
#include "infinicore/ops/causal_softmax.hpp"
4+
#include "infinicore/ops/common/cache.hpp"
5+
#include <infiniop.h>
6+
7+
namespace infinicore::op::causal_softmax_impl::infiniop {
8+
9+
// Per-thread cache of infiniop causal-softmax descriptors, keyed by a size_t hash.
// The callback destroys a descriptor and resets the handle to nullptr.
// NOTE(review): presumably OpCache invokes it when an entry is evicted or the
// cache is torn down — confirm against common::OpCache.
thread_local common::OpCache<size_t, infiniopCausalSoftmaxDescriptor_t> caches(
    100, // capacity
    [](infiniopCausalSoftmaxDescriptor_t &desc) {
        if (desc == nullptr) {
            return;
        }
        INFINICORE_CHECK_ERROR(infiniopDestroyCausalSoftmaxDescriptor(desc));
        desc = nullptr;
    });
17+
18+
/// infiniop-backed causal softmax kernel.
///
/// Looks up (or creates and caches) a descriptor for this output/input pair in
/// the current device's cache, allocates the workspace the descriptor asks
/// for, and launches the infiniop kernel on the current stream.
void calculate(Tensor output, Tensor input) {
    // Cache key derived from both tensors.
    const size_t key = hash_combine(output, input);

    const auto device_type = context::getDevice().getType();
    const auto device_index = context::getDevice().getIndex();
    auto &device_cache = caches.getCache(device_type, device_index);

    infiniopCausalSoftmaxDescriptor_t desc = nullptr;
    if (auto cached = device_cache.get(key)) {
        desc = *cached;
    } else {
        // Cache miss: build the descriptor once and keep it for later calls.
        INFINICORE_CHECK_ERROR(infiniopCreateCausalSoftmaxDescriptor(
            context::getInfiniopHandle(), &desc,
            output->desc(), input->desc()));
        device_cache.put(key, desc);
    }

    size_t workspace_size = 0;
    INFINICORE_CHECK_ERROR(infiniopGetCausalSoftmaxWorkspaceSize(desc, &workspace_size));
    std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);

    INFINICORE_CHECK_ERROR(infiniopCausalSoftmax(
        desc, workspace->data(), workspace_size,
        output->data(), input->data(), context::getStream()));
}
46+
47+
// Runs during static initialization: registers `calculate` with the
// CausalSoftmax dispatcher via registerAll. The second argument is passed as
// `false` — NOTE(review): presumably "do not override existing entries"; confirm
// against OpDispatcher::registerAll.
static bool registered = [] {
    CausalSoftmax::dispatcher().registerAll(&calculate, false);
    return true;
}();
51+
52+
} // namespace infinicore::op::causal_softmax_impl::infiniop

src/infinicore/pybind11/ops.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include "ops/add.hpp"
66
#include "ops/attention.hpp"
7+
#include "ops/causal_softmax.hpp"
78
#include "ops/matmul.hpp"
89
#include "ops/rearrange.hpp"
910
#include "ops/rms_norm.hpp"
@@ -17,6 +18,7 @@ namespace infinicore::ops {
1718
inline void bind(py::module &m) {
1819
bind_add(m);
1920
bind_attention(m);
21+
bind_causal_softmax(m);
2022
bind_matmul(m);
2123
bind_rearrange(m);
2224
bind_rms_norm(m);
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#pragma once
2+
3+
#include <pybind11/pybind11.h>
4+
5+
#include "infinicore/ops/causal_softmax.hpp"
6+
7+
namespace py = pybind11;
8+
9+
namespace infinicore::ops {
10+
11+
/// Registers the causal softmax functions on the Python module `m`.
inline void bind_causal_softmax(py::module &m) {
    // Variant that returns a new tensor.
    m.def("causal_softmax", &op::causal_softmax, py::arg("input"),
          R"doc(Causal softmax activation function.)doc");

    // Variant that writes into a caller-provided output tensor.
    m.def("causal_softmax_", &op::causal_softmax_, py::arg("output"), py::arg("input"),
          R"doc(In-place causal softmax activation function.)doc");
}
23+
24+
} // namespace infinicore::ops
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import sys
2+
import os
3+
4+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
5+
6+
import torch
7+
import infinicore
8+
from framework.base import BaseOperatorTest, TensorSpec, TestCase
9+
from framework.runner import GenericTestRunner
10+
11+
# ==============================================================================
12+
# Operator-specific configuration
13+
# ==============================================================================
14+
15+
# Test cases format: (operation_mode, shape, input_strides, output_strides)
16+
# Causal softmax is a single-input function that applies causal masking before softmax
17+
_TEST_CASES_DATA = [
    # 2D inputs, default (contiguous) layout
    (TestCase.BOTH, (3, 3), None, None),
    (TestCase.BOTH, (32, 512), None, None),
    # 2D input and output with a row stride (1024) larger than the row length (512)
    (TestCase.BOTH, (32, 512), (1024, 1), (1024, 1)),
    # 3D inputs
    (TestCase.BOTH, (32, 5, 5), None, None),
    (TestCase.BOTH, (32, 20, 512), None, None),
    # 3D input with explicit strides, default output layout
    (TestCase.BOTH, (32, 20, 512), (20480, 512, 1), None),
    (TestCase.BOTH, (28, 15, 15), None, None),
]
29+
30+
31+
def parse_test_cases(data):
    """Build a ``TestCase`` from one raw causal_softmax test tuple.

    ``data`` is laid out as
    ``(operation_mode, shape, input_strides, output_strides)``; the two stride
    entries may be omitted or ``None``, in which case a plain tensor spec is
    used for that tensor.
    """
    operation_mode, shape = data[0], data[1]
    input_strides = data[2] if len(data) > 2 else None
    output_strides = data[3] if len(data) > 3 else None

    def _make_spec(strides):
        # Strided spec only when explicit strides were supplied.
        if strides is None:
            return TensorSpec.from_tensor(shape)
        return TensorSpec.from_strided_tensor(shape, strides)

    # The operator takes a single tensor input and produces one output.
    inputs = [_make_spec(input_strides)]
    output = _make_spec(output_strides)

    return TestCase(operation_mode, inputs, output)
57+
58+
59+
# Parse test cases
60+
_TEST_CASES = [parse_test_cases(entry) for entry in _TEST_CASES_DATA]

# Dtypes every test case is run with
_TENSOR_DTYPES = [
    infinicore.float16,
    infinicore.bfloat16,
    infinicore.float32,
]

# Per-dtype absolute / relative comparison tolerances
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 1e-3, "rtol": 1e-2},
    infinicore.float32: {"atol": 3e-5, "rtol": 1e-5},
    infinicore.bfloat16: {"atol": 5e-3, "rtol": 5e-2},
}
71+
72+
73+
class OpTest(BaseOperatorTest):
    """Operator test that checks infinicore causal_softmax against a torch reference."""

    def __init__(self):
        super().__init__("CausalSoftmax")

    def get_test_cases(self):
        return _TEST_CASES

    def get_tensor_dtypes(self):
        return _TENSOR_DTYPES

    def get_tolerance_map(self):
        return _TOLERANCE_MAP

    def torch_operator(self, input, out=None, **kwargs):
        """Reference implementation: mask future positions, then softmax over the last dim."""
        original_dtype = input.dtype

        # Ones strictly below the diagonal, flipped along the last two dims:
        # position (row, col) is masked when col > row + (cols - rows).
        causal_mask = torch.tril(torch.ones_like(input), diagonal=-1).flip(dims=[-2, -1])
        masked_scores = torch.where(
            causal_mask == 1, -torch.inf, input.to(torch.float32)
        )

        probs = torch.nn.functional.softmax(
            masked_scores, dim=-1, dtype=original_dtype
        )

        if out is None:
            return probs

        out.copy_(probs)
        return out

    def infinicore_operator(self, input, out=None, **kwargs):
        """Operator under test."""
        return infinicore.causal_softmax(input, out=out)
105+
106+
107+
def main():
    """Run the CausalSoftmax operator tests and exit with the runner's status."""
    GenericTestRunner(OpTest).run_and_exit()


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)