Skip to content

Commit e74339c

Browse files
author
zhangyue
committed
feat(ascend): add scaled softmax operator
1 parent 64751ea commit e74339c

3 files changed

Lines changed: 245 additions & 0 deletions

File tree

src/base/scaled_softmax.h

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#ifndef INFINI_OPS_BASE_SCALED_SOFTMAX_H_
2+
#define INFINI_OPS_BASE_SCALED_SOFTMAX_H_
3+
4+
#include <cassert>
5+
#include <cmath>
6+
#include <cstddef>
7+
8+
#include "data_type.h"
9+
#include "operator.h"
10+
#include "tensor.h"
11+
12+
namespace infini::ops {
13+
14+
// Device-independent base of the scaled softmax operator.
//
// Construction validates the arguments (via `assert`, so only in builds
// without `NDEBUG`) and caches the metadata backends need: the scale, the two
// extents of the 2D `[batch, vocab]` input, its dtype, and the strides of
// `input` and `out`. Backends derive from this class and implement
// `operator()`.
class ScaledSoftmax : public Operator<ScaledSoftmax> {
 public:
  // Validates `input`, `scale`, and `out`, then records their metadata.
  //
  // `input` must be 2D; `out` must have the same shape and dtype as `input`;
  // the dtype must be float16, bfloat16, float32, or float64; `scale` must be
  // finite.
  ScaledSoftmax(const Tensor input, double scale, Tensor out)
      : scale_{scale},
        // The rank is checked inside `ExtentOf2D`, i.e. before dimension 1 is
        // read, so a tensor of the wrong rank trips the diagnostic instead of
        // being indexed out of range first.
        batch_size_{ExtentOf2D(input, 0)},
        vocab_size_{ExtentOf2D(input, 1)},
        dtype_{input.dtype()},
        input_strides_{input.strides()},
        out_strides_{out.strides()} {
    assert(input.shape() == out.shape() &&
           "`ScaledSoftmax` requires `input` and `out` to have the same shape");
    assert(input.dtype() == out.dtype() &&
           "`ScaledSoftmax` requires `input` and `out` to have the same dtype");
    assert((dtype_ == DataType::kFloat16 || dtype_ == DataType::kBFloat16 ||
            dtype_ == DataType::kFloat32 || dtype_ == DataType::kFloat64) &&
           "`ScaledSoftmax` requires a floating point dtype");
    assert(std::isfinite(scale_) &&
           "`ScaledSoftmax` requires a finite `scale`");
  }

  // Runs the operator; implemented by each device backend.
  virtual void operator()(const Tensor input, double scale,
                          Tensor out) const = 0;

 protected:
  // Scale factor captured at construction.
  double scale_{1.0};

  // Extent of dimension 0 of `input`.
  Tensor::Size batch_size_{0};

  // Extent of dimension 1 of `input`.
  Tensor::Size vocab_size_{0};

  // Element type shared by `input` and `out`.
  DataType dtype_;

  // Strides of `input` as seen at construction.
  Tensor::Strides input_strides_;

  // Strides of `out` as seen at construction.
  Tensor::Strides out_strides_;

 private:
  // Returns the extent of `axis` after asserting that `input` is 2D.
  static Tensor::Size ExtentOf2D(const Tensor& input, int axis) {
    assert(input.ndim() == 2 &&
           "`ScaledSoftmax` currently supports 2D `[batch, vocab]` input");
    return input.size(axis);
  }
};
52+
53+
} // namespace infini::ops
54+
55+
#endif // INFINI_OPS_BASE_SCALED_SOFTMAX_H_
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
#ifndef INFINI_OPS_ASCEND_SCALED_SOFTMAX_KERNEL_H_
2+
#define INFINI_OPS_ASCEND_SCALED_SOFTMAX_KERNEL_H_
3+
4+
#include <cmath>
5+
6+
#include "acl/acl.h"
7+
#include "aclnn/aclnn_base.h"
8+
#include "aclnn_mul.h"
9+
#include "aclnn_softmax.h"
10+
#include "base/scaled_softmax.h"
11+
#include "data_type.h"
12+
#include "native/ascend/common.h"
13+
#include "native/ascend/workspace_pool_.h"
14+
#include "operator.h"
15+
16+
namespace infini::ops {
17+
18+
template <>
19+
class Operator<ScaledSoftmax, Device::Type::kAscend> : public ScaledSoftmax {
20+
public:
21+
Operator(const Tensor input, double scale, Tensor out)
22+
: ScaledSoftmax(input, scale, out),
23+
in_cache_(input),
24+
out_cache_(out),
25+
temp_cache_(input),
26+
scale_storage_(static_cast<float>(scale)),
27+
needs_scale_(std::fabs(scale - 1.0) > 1e-6) {
28+
assert((dtype_ == DataType::kFloat16 || dtype_ == DataType::kBFloat16 ||
29+
dtype_ == DataType::kFloat32) &&
30+
"`ScaledSoftmax` Ascend path requires float16, bfloat16, or "
31+
"float32 input");
32+
assert(input.IsContiguous() &&
33+
"`ScaledSoftmax` Ascend path requires contiguous input");
34+
assert(out.IsContiguous() &&
35+
"`ScaledSoftmax` Ascend path requires contiguous output");
36+
37+
temp_size_ = input.numel() * kDataTypeToSize.at(dtype_);
38+
scale_scalar_ = aclCreateScalar(&scale_storage_, ACL_FLOAT);
39+
}
40+
41+
~Operator() {
42+
if (!ascend::IsAclRuntimeAlive()) return;
43+
44+
in_cache_.release();
45+
out_cache_.release();
46+
temp_cache_.release();
47+
48+
if (scale_scalar_) aclDestroyScalar(scale_scalar_);
49+
}
50+
51+
void operator()(const Tensor input, double scale, Tensor out) const override {
52+
assert(scale == scale_ &&
53+
"`ScaledSoftmax` scale changed after descriptor creation");
54+
55+
auto stream = static_cast<aclrtStream>(stream_);
56+
auto t_in = in_cache_.get(const_cast<void*>(input.data()));
57+
auto t_out = out_cache_.get(out.data());
58+
aclTensor* t_softmax_in = t_in;
59+
void* softmax_in_data = const_cast<void*>(input.data());
60+
61+
if (needs_scale_) {
62+
auto& temp =
63+
ascend::GetWorkspacePool().Ensure(stream, temp_size_, "temp");
64+
auto t_temp = temp_cache_.get(temp.buf);
65+
66+
if (!muls_exec_) {
67+
aclnnMulsGetWorkspaceSize(t_in, scale_scalar_, t_temp, &muls_ws_,
68+
&muls_exec_);
69+
aclSetAclOpExecutorRepeatable(muls_exec_);
70+
} else {
71+
aclSetInputTensorAddr(muls_exec_, 0, t_in,
72+
const_cast<void*>(input.data()));
73+
aclSetOutputTensorAddr(muls_exec_, 0, t_temp, temp.buf);
74+
}
75+
76+
auto& muls_arena = ascend::GetWorkspacePool().Ensure(stream, muls_ws_);
77+
aclnnMuls(muls_arena.buf, muls_ws_, muls_exec_, stream);
78+
79+
t_softmax_in = t_temp;
80+
softmax_in_data = temp.buf;
81+
}
82+
83+
if (!softmax_exec_) {
84+
constexpr int64_t kLastDim = -1;
85+
aclnnSoftmaxGetWorkspaceSize(t_softmax_in, kLastDim, t_out, &softmax_ws_,
86+
&softmax_exec_);
87+
aclSetAclOpExecutorRepeatable(softmax_exec_);
88+
} else {
89+
aclSetInputTensorAddr(softmax_exec_, 0, t_softmax_in, softmax_in_data);
90+
aclSetOutputTensorAddr(softmax_exec_, 0, t_out, out.data());
91+
}
92+
93+
auto& softmax_arena =
94+
ascend::GetWorkspacePool().Ensure(stream, softmax_ws_);
95+
aclnnSoftmax(softmax_arena.buf, softmax_ws_, softmax_exec_, stream);
96+
}
97+
98+
private:
99+
mutable ascend::AclTensorCache in_cache_;
100+
101+
mutable ascend::AclTensorCache out_cache_;
102+
103+
mutable ascend::AclTensorCache temp_cache_;
104+
105+
float scale_storage_{1.0f};
106+
107+
aclScalar* scale_scalar_ = nullptr;
108+
109+
bool needs_scale_{false};
110+
111+
uint64_t temp_size_{0};
112+
113+
mutable aclOpExecutor* muls_exec_ = nullptr;
114+
115+
mutable uint64_t muls_ws_ = 0;
116+
117+
mutable aclOpExecutor* softmax_exec_ = nullptr;
118+
119+
mutable uint64_t softmax_ws_ = 0;
120+
};
121+
122+
} // namespace infini::ops
123+
124+
#endif // INFINI_OPS_ASCEND_SCALED_SOFTMAX_KERNEL_H_

tests/test_scaled_softmax.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import infini.ops
2+
import pytest
3+
import torch
4+
5+
from tests.utils import Payload, empty_strided, get_stream, randn_strided
6+
7+
8+
@pytest.mark.auto_act_and_assert
@pytest.mark.parametrize("shape", ((1, 7), (3, 11), (16, 512)))
@pytest.mark.parametrize("scale", (1.0, 0.5, 1.7))
@pytest.mark.parametrize(
    ("dtype", "rtol", "atol"),
    (
        (torch.float32, 1e-5, 1e-5),
        (torch.float16, 1e-2, 1e-2),
        (torch.bfloat16, 1e-2, 1e-2),
    ),
)
def test_scaled_softmax(
    shape, scale, dtype, device, implementation_index, rtol, atol
):
    """Build the comparison payload for `infini.ops.scaled_softmax`.

    Pairs the operator wrapper with the PyTorch reference on a freshly
    generated 2D input and an uninitialised output of the same shape and
    dtype. The returned `Payload` is presumably executed and checked by the
    `auto_act_and_assert` marker; confirm against the test harness.
    """
    source = randn_strided(shape, None, dtype=dtype, device=device)
    destination = empty_strided(shape, None, dtype=dtype, device=device)
    call_kwargs = {
        "scale": scale,
        "implementation_index": implementation_index,
    }

    return Payload(
        _scaled_softmax,
        _torch_scaled_softmax,
        (source, destination),
        call_kwargs,
        rtol=rtol,
        atol=atol,
    )
46+
47+
48+
def _scaled_softmax(input_tensor, out, *, scale, implementation_index):
    """Run `infini.ops.scaled_softmax` into `out` and return `out`.

    The call is issued on the stream that `get_stream` yields for the input's
    device, using the requested implementation index.
    """
    stream = get_stream(input_tensor.device)
    infini.ops.scaled_softmax(
        input_tensor,
        scale,
        out,
        stream=stream,
        implementation_index=implementation_index,
    )

    return out
58+
59+
60+
def _torch_scaled_softmax(input_tensor, out, *, scale, implementation_index):
61+
del implementation_index
62+
63+
result = torch.nn.functional.softmax(input_tensor.to(torch.float32) * scale, dim=-1)
64+
out.copy_(result.to(input_tensor.dtype))
65+
66+
return out

0 commit comments

Comments
 (0)