Skip to content

Commit a1d8229

Browse files
hsharma35 authored and facebook-github-bot committed
Fix optional bias and batch handling in cadence::fully_connected (#19194)
Summary: Fixes two bugs in the generic and HiFi cadence::fully_connected implementations. First, the optional bias was dereferenced without a has_value() guard, causing a crash for bias-free inputs. Second, only the first input row was computed because the batch loop was missing; a loop over leading_dims (the product of all non-channel input dimensions) is now added to correctly process batched and multi-sequence inputs. Reviewed By: mcremon-meta Differential Revision: D102821213
1 parent 2406e5a commit a1d8229

3 files changed

Lines changed: 5 additions & 6 deletions

File tree

backends/cadence/aot/ops_registrations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2527,7 +2527,7 @@ def quantized_max_pool2d_nhwc_meta(
25272527
def fully_connected_meta(
25282528
src: torch.Tensor,
25292529
weight: torch.Tensor,
2530-
bias: torch.Tensor,
2530+
bias: Optional[torch.Tensor] = None,
25312531
) -> torch.Tensor:
25322532
# src comes in shape [leading_dims, in_dim]
25332533
# weight comes in shape [out_dim, in_dim]

backends/cadence/aot/ref_implementations.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -633,10 +633,8 @@ def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor:
633633
def fully_connected(
634634
input_tensor: torch.Tensor,
635635
weight: torch.Tensor,
636-
bias: torch.Tensor,
636+
bias: Optional[torch.Tensor] = None,
637637
) -> torch.Tensor:
638-
if input_tensor.shape[0] != 1:
639-
raise ValueError("Fully connected linear only supports batch size of 1")
640638
return F.linear(input_tensor, weight, bias)
641639

642640

backends/cadence/generic/operators/op_fully_connected.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ void linear(
2727
Tensor& output) {
2828
const float* __restrict__ input_data = input.const_data_ptr<float>();
2929
const float* __restrict__ weight_data = weight.const_data_ptr<float>();
30-
const float* __restrict__ bias_data = bias.value().const_data_ptr<float>();
30+
const float* __restrict__ bias_data =
31+
bias.has_value() ? bias.value().const_data_ptr<float>() : nullptr;
3132
float* __restrict__ output_data = output.mutable_data_ptr<float>();
3233

3334
// input comes in shape [batch_size, in_dim]
@@ -43,7 +44,7 @@ void linear(
4344

4445
for (int i = 0; i < leading_dims; ++i) {
4546
for (int j = 0; j < M; ++j) {
46-
float sum = bias_data[j];
47+
float sum = bias_data != nullptr ? bias_data[j] : 0.0f;
4748
for (int k = 0; k < N; ++k) {
4849
sum += input_data[i * N + k] * weight_data[j * N + k];
4950
}

0 commit comments

Comments (0)