Skip to content

Commit 8d3f6bc

Browse files
committed
feat(burn): fused SIMD sigmoid via hpc::activations::sigmoid_f32
Override ActivationOps::sigmoid with fused F32x16 SIMD path. Default burn sigmoid: 6 separate ops (neg, exp, add, log, neg, exp) Our sigmoid: one fused pass: 1/(1+exp(-x)) via F32x16 polynomial For contiguous f32: use hpc::activations::sigmoid_f32 (F32x16 SIMD) For non-f32 or non-contiguous: decomposed via Backend float ops The fused path eliminates 5 intermediate tensor allocations and does the full sigmoid in a single pass over the data. 30 tests passing. Zero regressions. https://claude.ai/code/session_01Y69Vnw751w75iVSBRws7o7
1 parent 984d50c commit 8d3f6bc

1 file changed

Lines changed: 28 additions & 1 deletion

File tree

crates/burn/src/ops/activation.rs

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::{
2-
NdArray, NdArrayTensor, SharedArray,
2+
NdArray, NdArrayStorage, NdArrayTensor, SharedArray,
33
element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
44
execute_with_numeric_dtype,
55
ops::NdArrayMathOps,
@@ -15,4 +15,31 @@ where
1515
fn relu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1616
execute_with_numeric_dtype!(tensor, |array| NdArrayMathOps::clamp_min(array, 0.elem()))
1717
}
18+
19+
/// Sigmoid via ndarray::hpc::activations::sigmoid_f32 (fused F32x16 SIMD).
20+
///
21+
/// Default impl decomposes into 6 separate ops: neg, exp, add, log, neg, exp.
22+
/// Our version does `1 / (1 + exp(-x))` in one SIMD pass with F32x16.
23+
fn sigmoid(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
24+
#[cfg(feature = "simd")]
25+
if let NdArrayTensor::F32(ref storage) = tensor {
26+
let view = storage.view();
27+
if view.is_standard_layout() {
28+
if let Some(input) = view.as_slice() {
29+
let mut output = alloc::vec![0.0f32; input.len()];
30+
ndarray::hpc::activations::sigmoid_f32(input, &mut output);
31+
let shape: alloc::vec::Vec<usize> = view.shape().to_vec();
32+
let array = ndarray::Array::from_shape_vec(ndarray::IxDyn(&shape), output)
33+
.expect("sigmoid output shape mismatch");
34+
return NdArrayTensor::F32(NdArrayStorage::Owned(array.into_shared()));
35+
}
36+
}
37+
}
38+
// Fallback: decomposed sigmoid via Backend ops (non-f32 or non-contiguous).
39+
use burn_backend::ops::FloatTensorOps;
40+
let tensor_neg = Self::float_neg(tensor);
41+
let tensor_exp = Self::float_exp(tensor_neg);
42+
let tensor_add = Self::float_add_scalar(tensor_exp, 1.0.into());
43+
Self::float_recip(tensor_add)
44+
}
1845
}

0 commit comments

Comments
 (0)