File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -2527,7 +2527,7 @@ def quantized_max_pool2d_nhwc_meta(
 def fully_connected_meta(
     src: torch.Tensor,
     weight: torch.Tensor,
-    bias: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
Original file line number Diff line number Diff line change @@ -633,10 +633,8 @@ def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor:
 def fully_connected(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
-    bias: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    if input_tensor.shape[0] != 1:
-        raise ValueError("Fully connected linear only supports batch size of 1")
     return F.linear(input_tensor, weight, bias)
641639
642640
Original file line number Diff line number Diff line change @@ -27,7 +27,8 @@ void linear(
     Tensor& output) {
   const float* __restrict__ input_data = input.const_data_ptr<float>();
   const float* __restrict__ weight_data = weight.const_data_ptr<float>();
-  const float* __restrict__ bias_data = bias.value().const_data_ptr<float>();
+  const float* __restrict__ bias_data =
+      bias.has_value() ? bias.value().const_data_ptr<float>() : nullptr;
   float* __restrict__ output_data = output.mutable_data_ptr<float>();

   // input comes in shape [batch_size, in_dim]
@@ -43,7 +44,7 @@ void linear(

   for (int i = 0; i < leading_dims; ++i) {
     for (int j = 0; j < M; ++j) {
-      float sum = bias_data[j];
+      float sum = bias_data != nullptr ? bias_data[j] : 0.0f;
       for (int k = 0; k < N; ++k) {
         sum += input_data[i * N + k] * weight_data[j * N + k];
       }
You can’t perform that action at this time.
0 commit comments