@@ -37,10 +37,41 @@ Tensor conv2d(Tensor input,
3737 const std::vector<size_t > &pads,
3838 const std::vector<size_t > &strides,
3939 const std::vector<size_t > &dilations) {
40- // Output shape should be pre-computed by caller; allocate a conservative placeholder.
41- // This helper is rarely used in performance-critical paths.
42- Shape shape = input->shape ();
43- auto output = Tensor::empty (shape, input->dtype (), input->device ());
40+ const auto &in_shape = input->shape (); // [N, C_in, H_in, W_in]
41+ const auto &w_shape = weight->shape (); // [C_out, C_in, kH, kW]
42+
43+ // -------------------------------
44+ // Extract dimensions
45+ // -------------------------------
46+ size_t N = in_shape[0 ];
47+ size_t C_in = in_shape[1 ];
48+ size_t H_in = in_shape[2 ];
49+ size_t W_in = in_shape[3 ];
50+
51+ size_t C_out = w_shape[0 ];
52+ size_t kH = w_shape[2 ];
53+ size_t kW = w_shape[3 ];
54+
55+ size_t pad_h = pads[0 ];
56+ size_t pad_w = pads[1 ];
57+
58+ size_t stride_h = strides[0 ];
59+ size_t stride_w = strides[1 ];
60+
61+ size_t dil_h = dilations[0 ];
62+ size_t dil_w = dilations[1 ];
63+
// Computes the output extent along one spatial axis:
//   floor((in + 2 * pad - dilation * (kernel - 1) - 1) / stride) + 1
// All operands are unsigned, so the subtraction must never be evaluated when
// the dilated kernel is larger than the padded input: the wrapped result can
// still look like a valid positive size after dividing by the stride.
// Returns 0 for any configuration that has no valid output (zero stride,
// zero kernel size, or a kernel window that does not fit); the caller's
// "<= 0" check then rejects it.
auto calc_out = [](size_t in, size_t pad, size_t dilation, size_t kernel, size_t stride) -> size_t {
    if (stride == 0 || kernel == 0) {
        return 0; // would divide by zero / wrap `kernel - 1`
    }
    const size_t padded = in + 2 * pad;
    const size_t window = dilation * (kernel - 1) + 1; // extent covered by the dilated kernel
    if (padded < window) {
        return 0; // kernel does not fit; avoid unsigned wrap-around
    }
    return (padded - window) / stride + 1;
};
67+ size_t H_out = calc_out (H_in, pad_h, dil_h, kH , stride_h);
68+ size_t W_out = calc_out (W_in, pad_w, dil_w, kW , stride_w);
69+ if ((int64_t )H_out <= 0 || (int64_t )W_out <= 0 ) {
70+ throw std::runtime_error (" Invalid conv2d output shape (negative or zero)" );
71+ }
72+ Shape out_shape = {N, C_out, H_out, W_out};
73+
74+ auto output = Tensor::empty (out_shape, input->dtype (), input->device ());
4475 conv2d_ (output, input, weight, bias, pads, strides, dilations);
4576 return output;
4677}
0 commit comments