Commit e084b69

issue/1126 fix softmax and conv2d
1 parent 1cee498 commit e084b69

2 files changed

Lines changed: 62 additions & 4 deletions

src/infinicore/ops/conv2d/conv2d.cc

Lines changed: 35 additions & 4 deletions
@@ -37,10 +37,41 @@ Tensor conv2d(Tensor input,
               const std::vector<size_t> &pads,
               const std::vector<size_t> &strides,
               const std::vector<size_t> &dilations) {
-    // Output shape should be pre-computed by caller; allocate a conservative placeholder.
-    // This helper is rarely used in performance-critical paths.
-    Shape shape = input->shape();
-    auto output = Tensor::empty(shape, input->dtype(), input->device());
+    const auto &in_shape = input->shape(); // [N, C_in, H_in, W_in]
+    const auto &w_shape = weight->shape(); // [C_out, C_in, kH, kW]
+
+    // -------------------------------
+    // Extract dimensions
+    // -------------------------------
+    size_t N = in_shape[0];
+    size_t C_in = in_shape[1];
+    size_t H_in = in_shape[2];
+    size_t W_in = in_shape[3];
+
+    size_t C_out = w_shape[0];
+    size_t kH = w_shape[2];
+    size_t kW = w_shape[3];
+
+    size_t pad_h = pads[0];
+    size_t pad_w = pads[1];
+
+    size_t stride_h = strides[0];
+    size_t stride_w = strides[1];
+
+    size_t dil_h = dilations[0];
+    size_t dil_w = dilations[1];
+
+    auto calc_out = [](size_t in, size_t pad, size_t dilation, size_t kernel, size_t stride) {
+        return (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
+    };
+    size_t H_out = calc_out(H_in, pad_h, dil_h, kH, stride_h);
+    size_t W_out = calc_out(W_in, pad_w, dil_w, kW, stride_w);
+    if ((int64_t)H_out <= 0 || (int64_t)W_out <= 0) {
+        throw std::runtime_error("Invalid conv2d output shape (negative or zero)");
+    }
+    Shape out_shape = {N, C_out, H_out, W_out};
+
+    auto output = Tensor::empty(out_shape, input->dtype(), input->device());
     conv2d_(output, input, weight, bias, pads, strides, dilations);
     return output;
 }
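
The new code derives the output spatial size from the standard dilated-convolution formula, out = (in + 2*pad - dilation*(kernel - 1) - 1) / stride + 1, instead of reusing the input shape as the old placeholder did. A minimal standalone sketch of the same formula (illustrative only, not part of the commit; the concrete shapes are made-up examples):

// Check of the output-size formula used in conv2d above (illustrative only).
#include <cstddef>
#include <cstdio>

static size_t calc_out(size_t in, size_t pad, size_t dilation, size_t kernel, size_t stride) {
    return (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
}

int main() {
    // 3x3 kernel, pad 1, stride 1, dilation 1 keeps a 224x224 input at 224x224.
    std::printf("%zu\n", calc_out(224, 1, 1, 3, 1)); // prints 224
    // 7x7 kernel, pad 3, stride 2, dilation 1 downsamples 224x224 to 112x112.
    std::printf("%zu\n", calc_out(224, 3, 1, 7, 2)); // prints 112
    return 0;
}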

src/infiniop/ops/softmax/nvidia/softmax_nvidia.cu

Lines changed: 27 additions & 0 deletions
@@ -81,6 +81,33 @@ infiniStatus_t launchKernel(void *y, const void *x, infiniDtype_t dtype,
                                                      othersize, dimsize, stride);
         }
 
+    } else if (dtype == INFINI_DTYPE_BF16) {
+        if (dimsize > 1024) {
+            blockSoftmax<cuda_bfloat16, BLOCK_SIZE>
+                <<<num_blocks, BLOCK_SIZE, 0, stream>>>((cuda_bfloat16 *)y, (const cuda_bfloat16 *)x,
+                                                        dimsize, stride);
+        } else if (dimsize > 31) {
+            constexpr unsigned int BLOCK_SIZE_x = 32;
+            constexpr unsigned int BLOCK_SIZE_y = 32;
+            constexpr int numPerThreadx = 32;
+            int num_block_x = (num_blocks + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y;
+            dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1);
+            dim3 grid_dim(num_block_x, 1, 1);
+            warpSoftmax<cuda_bfloat16, BLOCK_SIZE_x, BLOCK_SIZE_y, numPerThreadx>
+                <<<grid_dim, block_dim, 0, stream>>>((cuda_bfloat16 *)y, (const cuda_bfloat16 *)x,
+                                                     othersize, dimsize, stride);
+        } else {
+            constexpr unsigned int BLOCK_SIZE_x = 16;
+            constexpr unsigned int BLOCK_SIZE_y = 32;
+            constexpr int numPerThreadx = 2;
+            int num_block_x = (num_blocks + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y;
+            dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1);
+            dim3 grid_dim(num_block_x, 1, 1);
+            warpSoftmax<cuda_bfloat16, BLOCK_SIZE_x, BLOCK_SIZE_y, numPerThreadx>
+                <<<grid_dim, block_dim, 0, stream>>>((cuda_bfloat16 *)y, (const cuda_bfloat16 *)x,
+                                                     othersize, dimsize, stride);
+        }
+
     } else if (dtype == INFINI_DTYPE_F32) {
         if (dimsize > 1024) {
             blockSoftmax<float, BLOCK_SIZE>
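
The BF16 path added above mirrors the existing F16/F32 dispatch: rows longer than 1024 elements fall back to blockSoftmax (a full thread block per row), mid-sized rows use warpSoftmax with BLOCK_SIZE_x = 32 and numPerThreadx = 32, and rows of at most 31 elements use BLOCK_SIZE_x = 16 with numPerThreadx = 2. A small host-side sketch of the threshold selection (illustrative only, not part of the commit; reading BLOCK_SIZE_x * numPerThreadx as the per-row capacity is an assumption):

// Which launch path the BF16 branch above would take for a given row length.
#include <cstdio>
#include <initializer_list>

const char *pick_softmax_path(int dimsize) {
    if (dimsize > 1024) {
        return "blockSoftmax<cuda_bfloat16, BLOCK_SIZE>";  // one block reduces a full row
    } else if (dimsize > 31) {
        return "warpSoftmax<cuda_bfloat16, 32, 32, 32>";   // 32 lanes x 32 elements, presumably up to 1024
    } else {
        return "warpSoftmax<cuda_bfloat16, 16, 32, 2>";    // 16 lanes x 2 elements, presumably up to 32
    }
}

int main() {
    for (int d : {16, 256, 4096}) {
        std::printf("dimsize=%d -> %s\n", d, pick_softmax_path(d));
    }
    return 0;
}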
