Skip to content

Commit 1b1dd5f

Browse files
authored
feat: add cuda relu infinilm (#678)
1 parent 56d87af commit 1b1dd5f

8 files changed

Lines changed: 353 additions & 0 deletions

File tree

src/base/relu_infinilm.h

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#ifndef INFINI_OPS_BASE_RELU_INFINILM_H_
2+
#define INFINI_OPS_BASE_RELU_INFINILM_H_
3+
4+
#include <cassert>
5+
6+
#include "operator.h"
7+
8+
namespace infini::ops {
9+
10+
/// Device-independent base of the `ReluInfinilm` operator.
///
/// Construction snapshots the layout metadata of both tensors (shape, strides,
/// dtype, contiguity) together with the output's element count, rank and
/// device index. Device-specific subclasses provide `operator()`.
class ReluInfinilm : public Operator<ReluInfinilm> {
 public:
  /// Records the metadata of `input` and `out`.
  ///
  /// Three preconditions are checked with `assert` (so only in builds where
  /// assertions are enabled): equal shapes, equal dtypes, and an output
  /// without broadcasted dimensions.
  ReluInfinilm(const Tensor input, Tensor out)
      : input_shape_{input.shape()},
        input_strides_{input.strides()},
        input_type_{input.dtype()},
        out_shape_{out.shape()},
        out_strides_{out.strides()},
        out_type_{out.dtype()},
        output_size_{out.numel()},
        ndim_{out.ndim()},
        is_input_contiguous_{input.IsContiguous()},
        is_out_contiguous_{out.IsContiguous()},
        device_index_{out.device().index()} {
    assert(input_shape_ == out_shape_ &&
           "`ReluInfinilm` input and output shapes must match");
    assert(input_type_ == out_type_ &&
           "`ReluInfinilm` input and output dtypes must match");
    assert(!out.HasBroadcastDim() &&
           "`ReluInfinilm` output must not have broadcasted dimensions");
  }

  /// Runs the operator on `input`, writing the result to `out`.
  virtual void operator()(const Tensor input, Tensor out) const = 0;

 protected:
  // --- Input layout, captured at construction. ---
  Tensor::Shape input_shape_;
  Tensor::Strides input_strides_;
  DataType input_type_;

  // --- Output layout, captured at construction. ---
  Tensor::Shape out_shape_;
  Tensor::Strides out_strides_;
  DataType out_type_;

  // Element count of `out` (`out.numel()`).
  Tensor::Size output_size_{0};

  // Rank of `out` (`out.ndim()`); the shape assertion above ties it to the
  // input's rank when assertions are enabled.
  Tensor::Size ndim_{0};

  // Contiguity flags as reported by `Tensor::IsContiguous()`.
  bool is_input_contiguous_{false};
  bool is_out_contiguous_{false};

  // Index of the device that holds `out`.
  int device_index_{0};
};
57+
58+
} // namespace infini::ops
59+
60+
#endif
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#ifndef INFINI_OPS_ILUVATAR_RELU_INFINILM_KERNEL_H_
2+
#define INFINI_OPS_ILUVATAR_RELU_INFINILM_KERNEL_H_
3+
4+
#include <utility>
5+
6+
#include "native/cuda/iluvatar/caster.cuh"
7+
#include "native/cuda/iluvatar/runtime_.h"
8+
#include "native/cuda/ops/relu_infinilm/kernel.h"
9+
10+
namespace infini::ops {
11+
12+
template <>
13+
class Operator<ReluInfinilm, Device::Type::kIluvatar>
14+
: public CudaReluInfinilm<Runtime<Device::Type::kIluvatar>> {
15+
public:
16+
using CudaReluInfinilm<Runtime<Device::Type::kIluvatar>>::CudaReluInfinilm;
17+
};
18+
19+
} // namespace infini::ops
20+
21+
#endif
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#ifndef INFINI_OPS_METAX_RELU_INFINILM_KERNEL_H_
2+
#define INFINI_OPS_METAX_RELU_INFINILM_KERNEL_H_
3+
4+
#include <utility>
5+
6+
#include "native/cuda/metax/caster.cuh"
7+
#include "native/cuda/metax/runtime_.h"
8+
#include "native/cuda/ops/relu_infinilm/kernel.h"
9+
10+
namespace infini::ops {
11+
12+
template <>
13+
class Operator<ReluInfinilm, Device::Type::kMetax>
14+
: public CudaReluInfinilm<Runtime<Device::Type::kMetax>> {
15+
public:
16+
using CudaReluInfinilm<Runtime<Device::Type::kMetax>>::CudaReluInfinilm;
17+
};
18+
19+
} // namespace infini::ops
20+
21+
#endif
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#ifndef INFINI_OPS_MOORE_RELU_INFINILM_KERNEL_H_
2+
#define INFINI_OPS_MOORE_RELU_INFINILM_KERNEL_H_
3+
4+
#include <utility>
5+
6+
#include "native/cuda/moore/caster.cuh"
7+
#include "native/cuda/moore/polyfills.cuh"
8+
#include "native/cuda/moore/runtime_.h"
9+
#include "native/cuda/ops/relu_infinilm/kernel.h"
10+
11+
namespace infini::ops {
12+
13+
template <>
14+
class Operator<ReluInfinilm, Device::Type::kMoore>
15+
: public CudaReluInfinilm<Runtime<Device::Type::kMoore>> {
16+
public:
17+
using CudaReluInfinilm<Runtime<Device::Type::kMoore>>::CudaReluInfinilm;
18+
};
19+
20+
} // namespace infini::ops
21+
22+
#endif
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#ifndef INFINI_OPS_NVIDIA_RELU_INFINILM_KERNEL_H_
2+
#define INFINI_OPS_NVIDIA_RELU_INFINILM_KERNEL_H_
3+
4+
#include <utility>
5+
6+
#include "native/cuda/nvidia/caster.cuh"
7+
#include "native/cuda/nvidia/runtime_.h"
8+
#include "native/cuda/ops/relu_infinilm/kernel.h"
9+
10+
namespace infini::ops {
11+
12+
template <>
13+
class Operator<ReluInfinilm, Device::Type::kNvidia>
14+
: public CudaReluInfinilm<Runtime<Device::Type::kNvidia>> {
15+
public:
16+
using CudaReluInfinilm<Runtime<Device::Type::kNvidia>>::CudaReluInfinilm;
17+
};
18+
19+
} // namespace infini::ops
20+
21+
#endif
src/native/cuda/ops/relu_infinilm/kernel.cuh

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#ifndef INFINI_OPS_CUDA_RELU_INFINILM_KERNEL_CUH_
2+
#define INFINI_OPS_CUDA_RELU_INFINILM_KERNEL_CUH_
3+
4+
#include <cstddef>
5+
6+
#include "native/cuda/caster.cuh"
7+
#include "native/cuda/kernel_commons.cuh"
8+
9+
namespace infini::ops {
10+
11+
namespace {
12+
13+
template <Device::Type kDev, typename T>
14+
__device__ __forceinline__ T ReluInfinilmValue(T x) {
15+
const float v = Caster<kDev>::template Cast<float>(x);
16+
return Caster<kDev>::template Cast<T>(v > 0.0f ? v : 0.0f);
17+
}
18+
19+
template <Device::Type kDev>
20+
__device__ __forceinline__ double ReluInfinilmValue(double x) {
21+
return x > 0.0 ? x : 0.0;
22+
}
23+
24+
} // namespace
25+
26+
template <Device::Type kDev, typename T, unsigned int block_size>
27+
__global__ void ReluInfinilmKernel(T* __restrict__ out,
28+
const T* __restrict__ input,
29+
const size_t* __restrict__ out_shape,
30+
const size_t* __restrict__ input_shape,
31+
const ptrdiff_t* __restrict__ out_strides,
32+
const ptrdiff_t* __restrict__ input_strides,
33+
size_t output_size, size_t ndim,
34+
bool out_contiguous, bool input_contiguous) {
35+
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
36+
37+
if (idx < output_size) {
38+
size_t out_idx =
39+
out_contiguous ? idx : IndexToOffset(idx, ndim, out_shape, out_strides);
40+
size_t input_idx =
41+
input_contiguous ? idx
42+
: IndexToOffset(idx, ndim, input_shape, input_strides);
43+
out[out_idx] = ReluInfinilmValue<kDev>(input[input_idx]);
44+
}
45+
}
46+
47+
} // namespace infini::ops
48+
49+
#endif
src/native/cuda/ops/relu_infinilm/kernel.h

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
#ifndef INFINI_OPS_CUDA_RELU_INFINILM_KERNEL_H_
2+
#define INFINI_OPS_CUDA_RELU_INFINILM_KERNEL_H_
3+
4+
#include <algorithm>
5+
#include <cstddef>
6+
#include <cstring>
7+
#include <vector>
8+
9+
#include "base/relu_infinilm.h"
10+
#include "common/generic_utils.h"
11+
#include "data_type.h"
12+
#include "dispatcher.h"
13+
#include "native/cuda/kernel_commons.cuh"
14+
#include "native/cuda/ops/relu_infinilm/kernel.cuh"
15+
#include "native/cuda/runtime_utils.h"
16+
17+
namespace infini::ops {
18+
19+
template <typename Backend>
20+
class CudaReluInfinilm : public ReluInfinilm {
21+
public:
22+
CudaReluInfinilm(const Tensor input, Tensor out) : ReluInfinilm{input, out} {
23+
size_t shape_size = ndim_ * sizeof(*d_input_shape_);
24+
size_t strides_size = ndim_ * sizeof(*d_input_strides_);
25+
const size_t metadata_size = 2 * (shape_size + strides_size);
26+
std::vector<std::byte> metadata(metadata_size);
27+
28+
Backend::Malloc((void**)&d_metadata_, metadata_size);
29+
30+
size_t offset = 0;
31+
d_input_shape_ = reinterpret_cast<Tensor::Size*>(d_metadata_ + offset);
32+
std::memcpy(metadata.data() + offset, input_shape_.data(), shape_size);
33+
offset += shape_size;
34+
35+
d_out_shape_ = reinterpret_cast<Tensor::Size*>(d_metadata_ + offset);
36+
std::memcpy(metadata.data() + offset, out_shape_.data(), shape_size);
37+
offset += shape_size;
38+
39+
d_input_strides_ = reinterpret_cast<Tensor::Stride*>(d_metadata_ + offset);
40+
std::memcpy(metadata.data() + offset, input_strides_.data(), strides_size);
41+
offset += strides_size;
42+
43+
d_out_strides_ = reinterpret_cast<Tensor::Stride*>(d_metadata_ + offset);
44+
std::memcpy(metadata.data() + offset, out_strides_.data(), strides_size);
45+
46+
Backend::Memcpy(d_metadata_, metadata.data(), metadata_size,
47+
Backend::MemcpyHostToDevice);
48+
}
49+
50+
~CudaReluInfinilm() { Backend::Free(d_metadata_); }
51+
52+
void operator()(const Tensor input, Tensor out) const override {
53+
auto cuda_stream =
54+
static_cast<typename Backend::Stream>(stream_ ? stream_ : 0);
55+
int block_size = std::min(
56+
RuntimeUtils<Backend::kDeviceType>::GetOptimalBlockSize(), 1024);
57+
dim3 block(std::min(static_cast<Tensor::Size>(block_size), output_size_));
58+
dim3 grid(utils::CeilDiv(output_size_, block.x));
59+
60+
DispatchFunc<AllFloatTypes, List<128, 256, 512, 1024>>(
61+
{static_cast<int64_t>(out_type_), block_size},
62+
[&](auto list_tag) {
63+
using T = TypeMapType<Backend::kDeviceType, ListGet<0>(list_tag)>;
64+
constexpr int kBlockSize = ListGet<1>(list_tag);
65+
66+
ReluInfinilmKernel<Backend::kDeviceType, T, kBlockSize>
67+
<<<grid, block, 0, cuda_stream>>>(
68+
reinterpret_cast<T*>(out.data()),
69+
reinterpret_cast<const T*>(input.data()), d_out_shape_,
70+
d_input_shape_, d_out_strides_, d_input_strides_,
71+
output_size_, ndim_, is_out_contiguous_,
72+
is_input_contiguous_);
73+
},
74+
"CudaReluInfinilm::operator()");
75+
}
76+
77+
private:
78+
std::byte* d_metadata_{nullptr};
79+
80+
Tensor::Size* d_input_shape_{nullptr};
81+
82+
Tensor::Size* d_out_shape_{nullptr};
83+
84+
Tensor::Stride* d_input_strides_{nullptr};
85+
86+
Tensor::Stride* d_out_strides_{nullptr};
87+
};
88+
89+
} // namespace infini::ops
90+
91+
#endif

tests/test_relu_infinilm.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import infini.ops
2+
import pytest
3+
import torch
4+
5+
from tests.utils import Payload, empty_strided, get_stream, rand_strided
6+
7+
8+
@pytest.mark.auto_act_and_assert
@pytest.mark.parametrize(
    "shape, input_strides, out_strides, inplace",
    (
        ((1, 3), None, None, False),
        ((1, 3), None, None, True),
        ((3, 3), None, None, False),
        ((3, 3), (5, 1), (5, 1), False),
        ((32, 20, 512), None, None, False),
        ((32, 20, 512), None, None, True),
        ((33, 333, 333), None, None, False),
        ((32, 256, 112, 112), None, None, False),
        ((3, 3, 13, 9, 17), None, None, False),
        (
            (3, 3, 13, 9, 17),
            (19890, 6630, 510, 34, 1),
            (19890, 6630, 510, 34, 1),
            False,
        ),
    ),
)
@pytest.mark.parametrize(
    ("dtype", "rtol", "atol"),
    (
        (torch.float32, 1e-7, 1e-7),
        (torch.float16, 1e-3, 1e-3),
        (torch.bfloat16, 1e-3, 1e-3),
    ),
)
def test_relu_infinilm(
    shape, input_strides, out_strides, inplace, dtype, device, rtol, atol
):
    """Build a payload comparing `infini.ops.relu_infinilm` with torch's ReLU.

    Covers contiguous and explicitly strided layouts, plus an in-place variant
    in which the input tensor doubles as the output.
    """
    source = rand_strided(shape, input_strides, dtype=dtype, device=device)
    # Rescale the samples as 2 * x - 1; presumably `rand_strided` draws from
    # [0, 1), which would make both signs appear - verify against tests.utils.
    source.mul_(2).sub_(1)

    if inplace:
        target = source
    else:
        target = empty_strided(shape, out_strides, dtype=dtype, device=device)

    return Payload(
        _relu_infinilm,
        _torch_relu_infinilm,
        (source, target),
        {},
        rtol=rtol,
        atol=atol,
    )
56+
57+
58+
def _relu_infinilm(input, out):
    """Run `infini.ops.relu_infinilm` on the stream for `input`'s device."""
    stream = get_stream(input.device)
    infini.ops.relu_infinilm(input, out, stream=stream)

    return out
62+
63+
64+
def _torch_relu_infinilm(input, out):
65+
result = torch.nn.functional.relu(input)
66+
out.copy_(result)
67+
68+
return out

0 commit comments

Comments
 (0)