Skip to content

Commit 984d50c

Browse files
committed
feat(burn): wire ndarray hpc::vml SIMD into float_exp/log/sqrt/abs
First augmentation of the burn backend with our crate::simd F32x16 path.

For contiguous f32 tensors, these operations now route through ndarray::hpc::vml, which uses crate::simd::F32x16 (AVX-512/AVX2 via LazyLock dispatch). Non-f32 or non-contiguous tensors fall through to the original scalar mapv_into path.

- float_exp → ndarray::hpc::vml::vsexp (F32x16 polynomial approx)
- float_log → ndarray::hpc::vml::vsln (F32x16 polynomial approx)
- float_sqrt → ndarray::hpc::vml::vssqrt (F32x16 hardware sqrt)
- float_abs → ndarray::hpc::vml::vsabs (F32x16 bitmask)

try_vml_unary() helper:
- Checks tensor is F32 variant + contiguous layout
- Extracts &[f32] slice (zero-copy read)
- Calls the VML function, which writes its result into a freshly allocated Vec<f32> output
- Wraps into NdArrayTensor::F32(Owned)
- Falls through to scalar on non-f32/non-contiguous

30 tests passing. Zero regressions.

https://claude.ai/code/session_01Y69Vnw751w75iVSBRws7o7
1 parent 129a959 commit 984d50c

1 file changed

Lines changed: 50 additions & 0 deletions

File tree

crates/burn/src/ops/tensor.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,34 @@ use num_traits::Float;
3232

3333
use libm::erf;
3434

35+
/// Attempts to run a unary `f32` operation through one of ndarray's
/// `hpc::vml` kernels (described upstream as F32x16 SIMD).
///
/// `vml_fn` has the VML shape `fn(input: &[f32], output: &mut [f32])`: it
/// reads `input` and writes one result per element into `output`, which this
/// helper allocates with the same length as `input`.
///
/// # Returns
///
/// * `Ok(result)` when `tensor` is the `F32` variant and its data is in
///   standard layout and can be viewed as a slice. `result` is a new
///   `NdArrayTensor::F32` with the input's shape, holding `vml_fn`'s output.
/// * `Err(tensor)` otherwise, so the caller can fall back to its scalar
///   path. A non-`F32` tensor is handed back untouched. An `F32` tensor that
///   is not contiguous is handed back re-wrapped as `NdArrayStorage::Owned`
///   around the shared array produced by `into_shared()`.
///
/// # Panics
///
/// Panics with "vml output shape mismatch" if the output buffer cannot be
/// reshaped to the input's shape. This is not expected, since the buffer is
/// allocated with the input's length.
#[cfg(feature = "simd")]
fn try_vml_unary(
    tensor: NdArrayTensor,
    vml_fn: fn(&[f32], &mut [f32]),
) -> Result<NdArrayTensor, NdArrayTensor> {
    // Anything that is not f32 goes straight back to the caller.
    let storage = match tensor {
        NdArrayTensor::F32(storage) => storage,
        other => return Err(other),
    };
    let data = storage.into_shared();

    // The kernel needs a flat slice, so require standard layout and a
    // successful slice view.
    let contiguous = if data.is_standard_layout() {
        data.as_slice()
    } else {
        None
    };

    // Run the kernel into a fresh buffer; the input is only read.
    let computed = contiguous.map(|input| {
        let mut output = vec![0.0f32; input.len()];
        vml_fn(input, &mut output);
        output
    });

    match computed {
        Some(output) => {
            let array = ndarray::Array::from_shape_vec(ndarray::IxDyn(data.shape()), output)
                .expect("vml output shape mismatch");
            Ok(NdArrayTensor::F32(crate::NdArrayStorage::Owned(
                array.into_shared(),
            )))
        }
        // Not contiguous: hand the (now shared) data back for the scalar path.
        None => Err(NdArrayTensor::F32(crate::NdArrayStorage::Owned(data))),
    }
}
62+
3563
#[cfg(feature = "std")]
3664
#[allow(dead_code)]
3765
fn round_ties_even_wrapper(x: f64) -> f64 {
@@ -446,12 +474,24 @@ where
446474
}
447475

448476
fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
477+
// Fast path: contiguous f32 → ndarray::hpc::vml::vsexp (F32x16 SIMD).
478+
// Falls through to scalar mapv_into for non-f32 or non-contiguous.
479+
#[cfg(feature = "simd")]
480+
let tensor = match try_vml_unary(tensor, ndarray::hpc::vml::vsexp) {
481+
Ok(result) => return result,
482+
Err(t) => t,
483+
};
449484
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
450485
array.mapv_into(|a: FloatElem| a.exp_elem()).into_shared()
451486
})
452487
}
453488

454489
fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
490+
#[cfg(feature = "simd")]
491+
let tensor = match try_vml_unary(tensor, ndarray::hpc::vml::vsln) {
492+
Ok(result) => return result,
493+
Err(t) => t,
494+
};
455495
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
456496
array.mapv_into(|a: FloatElem| a.log_elem()).into_shared()
457497
})
@@ -499,12 +539,22 @@ where
499539
}
500540

501541
fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
542+
#[cfg(feature = "simd")]
543+
let tensor = match try_vml_unary(tensor, ndarray::hpc::vml::vssqrt) {
544+
Ok(result) => return result,
545+
Err(t) => t,
546+
};
502547
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
503548
array.mapv_into(|a: FloatElem| a.sqrt_elem()).into_shared()
504549
})
505550
}
506551

507552
fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
553+
#[cfg(feature = "simd")]
554+
let tensor = match try_vml_unary(tensor, ndarray::hpc::vml::vsabs) {
555+
Ok(result) => return result,
556+
Err(t) => t,
557+
};
508558
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
509559
NdArrayMathOps::abs(array)
510560
})

0 commit comments

Comments
 (0)