Skip to content

Commit cc60b9b

Browse files
committed
feat(burn): wire SIMD sin/cos via hpc::vml (eliminate f64 roundtrip)
float_sin → ndarray::hpc::vml::vssin (F32x16 direct, no f64 conversion); float_cos → ndarray::hpc::vml::vscos (F32x16 direct, no f64 conversion). Original burn-ndarray: cast f32→f64, compute sin/cos, cast f64→f32. Our path: operate directly on f32 via SIMD polynomial approximation. Total SIMD-wired ops: exp, log, sqrt, abs, sin, cos, sigmoid (7 ops). 30 tests passing. https://claude.ai/code/session_01Y69Vnw751w75iVSBRws7o7
1 parent 8d3f6bc commit cc60b9b

1 file changed

Lines changed: 10 additions & 0 deletions

File tree

crates/burn/src/ops/tensor.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,11 @@ where
561561
}
562562

563563
fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
564+
#[cfg(feature = "simd")]
565+
let tensor = match try_vml_unary(tensor, ndarray::hpc::vml::vscos) {
566+
Ok(result) => return result,
567+
Err(t) => t,
568+
};
564569
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
565570
array
566571
.mapv_into(|a: FloatElem| (a.to_f64()).cos().elem())
@@ -577,6 +582,11 @@ where
577582
}
578583

579584
fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
585+
#[cfg(feature = "simd")]
586+
let tensor = match try_vml_unary(tensor, ndarray::hpc::vml::vssin) {
587+
Ok(result) => return result,
588+
Err(t) => t,
589+
};
580590
execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
581591
array
582592
.mapv_into(|a: FloatElem| (a.to_f64()).sin().elem())

0 commit comments

Comments
 (0)