Skip to content

Commit cbbeeee

Browse files
heiher authored and Alexhuszagh committed
LoongArch64 FP16 hardware support
LoongArch is a RISC instruction set architecture and is currently a Tier-2 (with host-tools) target [^1] in the Rust upstream community. This patch introduces FP16 conversion functions based on the LoongArch SIMD extension (LSX) to improve performance.

Benchmarks:

```
HalfFloatSliceExt::convert_from_f32_slice/constants
    time:   [10.816 ns 10.823 ns 10.831 ns]
    change: [-63.769% -63.728% -63.693%] (p = 0.00 < 0.05)
    Performance has improved.
HalfFloatSliceExt::convert_from_f32_slice/large
    time:   [137.68 ns 137.77 ns 137.88 ns]
    change: [-94.847% -94.841% -94.834%] (p = 0.00 < 0.05)
    Performance has improved.
HalfFloatSliceExt::convert_from_f64_slice/constants
    time:   [12.656 ns 12.669 ns 12.684 ns]
    change: [-78.455% -78.418% -78.367%] (p = 0.00 < 0.05)
    Performance has improved.
HalfFloatSliceExt::convert_from_f64_slice/large
    time:   [544.15 ns 544.49 ns 544.91 ns]
    change: [-89.799% -89.791% -89.781%] (p = 0.00 < 0.05)
    Performance has improved.
HalfFloatSliceExt::convert_to_f32_slice/constants
    time:   [6.0412 ns 6.0442 ns 6.0482 ns]
    change: [-74.100% -74.068% -74.042%] (p = 0.00 < 0.05)
    Performance has improved.
HalfFloatSliceExt::convert_to_f32_slice/large
    time:   [512.78 ns 513.08 ns 513.45 ns]
    change: [-77.628% -77.526% -77.422%] (p = 0.00 < 0.05)
    Performance has improved.
HalfFloatSliceExt::convert_to_f64_slice/constants
    time:   [10.779 ns 10.784 ns 10.792 ns]
    change: [-49.028% -48.922% -48.813%] (p = 0.00 < 0.05)
    Performance has improved.
HalfFloatSliceExt::convert_to_f64_slice/large
    time:   [923.19 ns 923.77 ns 924.50 ns]
    change: [-80.876% -80.862% -80.849%] (p = 0.00 < 0.05)
    Performance has improved.
```

[^1]: https://doc.rust-lang.org/stable/rustc/platform-support/loongarch-linux.html
1 parent 0dda27c commit cbbeeee

4 files changed

Lines changed: 141 additions & 3 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ for specific CPU features which avoids the runtime overhead and works in a `no_s
2828
| ------------ | ------------------ | ----- |
2929
| `x86`/`x86_64` | `f16c` | This supports conversion to/from `f16` only (including vector SIMD) and does not support any `bf16` or arithmetic operations. |
3030
| `aarch64` | `fp16` | This supports all operations on `f16` only. |
31+
| `loongarch64` | `lsx` | This supports conversion to/from `f16` only (including vector SIMD) and does not support any `bf16` or arithmetic operations. |
3132

3233
### More Documentation
3334

src/binary16/arch.rs

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,14 @@ mod x86;
99
#[cfg(target_arch = "aarch64")]
1010
mod aarch64;
1111

12+
#[cfg(target_arch = "loongarch64")]
13+
mod loongarch64;
14+
1215
macro_rules! convert_fn {
13-
(
14-
if x86_feature("f16c") { $f16c:expr }else if aarch64_feature("fp16") { $aarch64:expr }else { $fallback:expr }
15-
) => {
16+
(if x86_feature("f16c") { $f16c:expr }
17+
else if aarch64_feature("fp16") { $aarch64:expr }
18+
else if loongarch64_feature("lsx") { $loongarch64:expr }
19+
else { $fallback:expr }) => {
1620
cfg_if::cfg_if! {
1721
// Use intrinsics directly when a compile target or using no_std
1822
if #[cfg(all(
@@ -29,6 +33,12 @@ macro_rules! convert_fn {
2933
))] {
3034
$aarch64
3135
}
36+
else if #[cfg(all(
37+
target_arch = "loongarch64",
38+
target_feature = "lsx"
39+
))] {
40+
$loongarch64
41+
}
3242

3343
// Use CPU feature detection if using std
3444
else if #[cfg(all(
@@ -55,6 +65,17 @@ macro_rules! convert_fn {
5565
$fallback
5666
}
5767
}
68+
else if #[cfg(all(
69+
feature = "std",
70+
target_arch = "loongarch64",
71+
))] {
72+
use std::arch::is_loongarch_feature_detected;
73+
if is_loongarch_feature_detected!("lsx") {
74+
$loongarch64
75+
} else {
76+
$fallback
77+
}
78+
}
5879

5980
// Fallback to software
6081
else {
@@ -71,6 +92,8 @@ pub(crate) fn f32_to_f16(f: f32) -> u16 {
7192
unsafe { x86::f32_to_f16_x86_f16c(f) }
7293
} else if aarch64_feature("fp16") {
7394
unsafe { aarch64::f32_to_f16_fp16(f) }
95+
} else if loongarch64_feature("lsx") {
96+
unsafe { loongarch64::f32_to_f16_lsx(f) }
7497
} else {
7598
f32_to_f16_fallback(f)
7699
}
@@ -84,6 +107,8 @@ pub(crate) fn f64_to_f16(f: f64) -> u16 {
84107
unsafe { x86::f64_to_f16_x86_f16c(f) }
85108
} else if aarch64_feature("fp16") {
86109
unsafe { aarch64::f64_to_f16_fp16(f) }
110+
} else if loongarch64_feature("lsx") {
111+
f64_to_f16_fallback(f)
87112
} else {
88113
f64_to_f16_fallback(f)
89114
}
@@ -97,6 +122,8 @@ pub(crate) fn f16_to_f32(i: u16) -> f32 {
97122
unsafe { x86::f16_to_f32_x86_f16c(i) }
98123
} else if aarch64_feature("fp16") {
99124
unsafe { aarch64::f16_to_f32_fp16(i) }
125+
} else if loongarch64_feature("lsx") {
126+
unsafe { loongarch64::f16_to_f32_lsx(i) }
100127
} else {
101128
f16_to_f32_fallback(i)
102129
}
@@ -110,6 +137,8 @@ pub(crate) fn f16_to_f64(i: u16) -> f64 {
110137
unsafe { x86::f16_to_f64_x86_f16c(i) }
111138
} else if aarch64_feature("fp16") {
112139
unsafe { aarch64::f16_to_f64_fp16(i) }
140+
} else if loongarch64_feature("lsx") {
141+
unsafe { loongarch64::f16_to_f32_lsx(i) as f64 }
113142
} else {
114143
f16_to_f64_fallback(i)
115144
}
@@ -123,6 +152,8 @@ pub(crate) fn f32x4_to_f16x4(f: &[f32; 4]) -> [u16; 4] {
123152
unsafe { x86::f32x4_to_f16x4_x86_f16c(f) }
124153
} else if aarch64_feature("fp16") {
125154
unsafe { aarch64::f32x4_to_f16x4_fp16(f) }
155+
} else if loongarch64_feature("lsx") {
156+
unsafe { loongarch64::f32x4_to_f16x4_lsx(f) }
126157
} else {
127158
f32x4_to_f16x4_fallback(f)
128159
}
@@ -136,6 +167,8 @@ pub(crate) fn f16x4_to_f32x4(i: &[u16; 4]) -> [f32; 4] {
136167
unsafe { x86::f16x4_to_f32x4_x86_f16c(i) }
137168
} else if aarch64_feature("fp16") {
138169
unsafe { aarch64::f16x4_to_f32x4_fp16(i) }
170+
} else if loongarch64_feature("lsx") {
171+
unsafe { loongarch64::f16x4_to_f32x4_lsx(i) }
139172
} else {
140173
f16x4_to_f32x4_fallback(i)
141174
}
@@ -149,6 +182,8 @@ pub(crate) fn f64x4_to_f16x4(f: &[f64; 4]) -> [u16; 4] {
149182
unsafe { x86::f64x4_to_f16x4_x86_f16c(f) }
150183
} else if aarch64_feature("fp16") {
151184
unsafe { aarch64::f64x4_to_f16x4_fp16(f) }
185+
} else if loongarch64_feature("lsx") {
186+
unsafe { loongarch64::f64x4_to_f16x4_lsx(f) }
152187
} else {
153188
f64x4_to_f16x4_fallback(f)
154189
}
@@ -162,6 +197,8 @@ pub(crate) fn f16x4_to_f64x4(i: &[u16; 4]) -> [f64; 4] {
162197
unsafe { x86::f16x4_to_f64x4_x86_f16c(i) }
163198
} else if aarch64_feature("fp16") {
164199
unsafe { aarch64::f16x4_to_f64x4_fp16(i) }
200+
} else if loongarch64_feature("lsx") {
201+
unsafe { loongarch64::f16x4_to_f64x4_lsx(i) }
165202
} else {
166203
f16x4_to_f64x4_fallback(i)
167204
}
@@ -180,6 +217,13 @@ pub(crate) fn f32x8_to_f16x8(f: &[f32; 8]) -> [u16; 8] {
180217
aarch64::f32x4_to_f16x4_fp16);
181218
result
182219
}
220+
} else if loongarch64_feature("lsx") {
221+
{
222+
let mut result = [0u16; 8];
223+
convert_chunked_slice_4(f.as_slice(), result.as_mut_slice(),
224+
loongarch64::f32x4_to_f16x4_lsx);
225+
result
226+
}
183227
} else {
184228
f32x8_to_f16x8_fallback(f)
185229
}
@@ -198,6 +242,13 @@ pub(crate) fn f16x8_to_f32x8(i: &[u16; 8]) -> [f32; 8] {
198242
aarch64::f16x4_to_f32x4_fp16);
199243
result
200244
}
245+
} else if loongarch64_feature("lsx") {
246+
{
247+
let mut result = [0f32; 8];
248+
convert_chunked_slice_4(i.as_slice(), result.as_mut_slice(),
249+
loongarch64::f16x4_to_f32x4_lsx);
250+
result
251+
}
201252
} else {
202253
f16x8_to_f32x8_fallback(i)
203254
}
@@ -216,6 +267,13 @@ pub(crate) fn f64x8_to_f16x8(f: &[f64; 8]) -> [u16; 8] {
216267
aarch64::f64x4_to_f16x4_fp16);
217268
result
218269
}
270+
} else if loongarch64_feature("lsx") {
271+
{
272+
let mut result = [0u16; 8];
273+
convert_chunked_slice_4(f.as_slice(), result.as_mut_slice(),
274+
loongarch64::f64x4_to_f16x4_lsx);
275+
result
276+
}
219277
} else {
220278
f64x8_to_f16x8_fallback(f)
221279
}
@@ -234,6 +292,13 @@ pub(crate) fn f16x8_to_f64x8(i: &[u16; 8]) -> [f64; 8] {
234292
aarch64::f16x4_to_f64x4_fp16);
235293
result
236294
}
295+
} else if loongarch64_feature("lsx") {
296+
{
297+
let mut result = [0f64; 8];
298+
convert_chunked_slice_4(i.as_slice(), result.as_mut_slice(),
299+
loongarch64::f16x4_to_f64x4_lsx);
300+
result
301+
}
237302
} else {
238303
f16x8_to_f64x8_fallback(i)
239304
}
@@ -248,6 +313,8 @@ pub(crate) fn f32_to_f16_slice(src: &[f32], dst: &mut [u16]) {
248313
x86::f32x4_to_f16x4_x86_f16c)
249314
} else if aarch64_feature("fp16") {
250315
convert_chunked_slice_4(src, dst, aarch64::f32x4_to_f16x4_fp16)
316+
} else if loongarch64_feature("lsx") {
317+
convert_chunked_slice_4(src, dst, loongarch64::f32x4_to_f16x4_lsx)
251318
} else {
252319
slice_fallback(src, dst, f32_to_f16_fallback)
253320
}
@@ -262,6 +329,8 @@ pub(crate) fn f16_to_f32_slice(src: &[u16], dst: &mut [f32]) {
262329
x86::f16x4_to_f32x4_x86_f16c)
263330
} else if aarch64_feature("fp16") {
264331
convert_chunked_slice_4(src, dst, aarch64::f16x4_to_f32x4_fp16)
332+
} else if loongarch64_feature("lsx") {
333+
convert_chunked_slice_4(src, dst, loongarch64::f16x4_to_f32x4_lsx)
265334
} else {
266335
slice_fallback(src, dst, f16_to_f32_fallback)
267336
}
@@ -276,6 +345,8 @@ pub(crate) fn f64_to_f16_slice(src: &[f64], dst: &mut [u16]) {
276345
x86::f64x4_to_f16x4_x86_f16c)
277346
} else if aarch64_feature("fp16") {
278347
convert_chunked_slice_4(src, dst, aarch64::f64x4_to_f16x4_fp16)
348+
} else if loongarch64_feature("lsx") {
349+
convert_chunked_slice_4(src, dst, loongarch64::f64x4_to_f16x4_lsx)
279350
} else {
280351
slice_fallback(src, dst, f64_to_f16_fallback)
281352
}
@@ -290,6 +361,8 @@ pub(crate) fn f16_to_f64_slice(src: &[u16], dst: &mut [f64]) {
290361
x86::f16x4_to_f64x4_x86_f16c)
291362
} else if aarch64_feature("fp16") {
292363
convert_chunked_slice_4(src, dst, aarch64::f16x4_to_f64x4_fp16)
364+
} else if loongarch64_feature("lsx") {
365+
convert_chunked_slice_4(src, dst, loongarch64::f16x4_to_f64x4_lsx)
293366
} else {
294367
slice_fallback(src, dst, f16_to_f64_fallback)
295368
}

src/binary16/arch/loongarch64.rs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
use core::{mem::MaybeUninit, ptr};
2+
3+
#[cfg(target_arch = "loongarch64")]
4+
use core::arch::loongarch64::{lsx_vfcvt_h_s, lsx_vfcvtl_s_h, m128, m128i};
5+
6+
/////////////// loongarch64 lsx/lasx ////////////////
7+
8+
/// Converts one half-precision value, given as its raw `u16` bits, to `f32`
/// using the LSX intrinsic `lsx_vfcvtl_s_h`.
///
/// The scalar is written to the start of an otherwise zeroed 128-bit vector,
/// the vector is converted, and the first `f32` of the result is returned.
///
/// # Safety
///
/// The function is compiled with `#[target_feature(enable = "lsx")]`, so the
/// caller must make sure the running CPU supports LSX.
#[target_feature(enable = "lsx")]
9+
#[inline]
10+
pub(super) unsafe fn f16_to_f32_lsx(i: u16) -> f32 {
11+
// Start from an all-zero vector so every byte not written below is initialized.
let mut vec = MaybeUninit::<m128i>::zeroed();
12+
// Store `i` in the first 16 bits of the vector.
vec.as_mut_ptr().cast::<u16>().write(i);
13+
// NOTE(review): presumably widens the low four f16 lanes to four f32 lanes;
// confirm against the LSX intrinsics reference.
let retval = lsx_vfcvtl_s_h(vec.assume_init());
14+
// Reinterpret the result vector and read its first `f32`
// (the pointee type is inferred from the return type).
*(&retval as *const m128).cast()
15+
}
16+
17+
/// Converts one `f32` to half precision, returned as raw `u16` bits, using
/// the LSX intrinsic `lsx_vfcvt_h_s`.
///
/// The scalar is written to the start of an otherwise zeroed 128-bit vector,
/// the vector is converted, and the first `u16` of the result is returned.
///
/// # Safety
///
/// The function is compiled with `#[target_feature(enable = "lsx")]`, so the
/// caller must make sure the running CPU supports LSX.
#[target_feature(enable = "lsx")]
18+
#[inline]
19+
pub(super) unsafe fn f32_to_f16_lsx(f: f32) -> u16 {
20+
// Start from an all-zero vector so every byte not written below is initialized.
let mut vec = MaybeUninit::<m128>::zeroed();
21+
// Store `f` in the first 32 bits of the vector.
vec.as_mut_ptr().cast::<f32>().write(f);
22+
// The intrinsic takes two source vectors; the same one is passed for both.
// NOTE(review): presumably each operand fills one half of the eight f16
// result lanes, so the first lane comes from `f` either way; confirm.
let retval = lsx_vfcvt_h_s(vec.assume_init(), vec.assume_init());
23+
// Reinterpret the result vector and read its first `u16`
// (the pointee type is inferred from the return type).
*(&retval as *const m128i).cast()
24+
}
25+
26+
/// Converts four half-precision values, given as raw `u16` bits, to four
/// `f32` values with a single `lsx_vfcvtl_s_h` call.
///
/// # Safety
///
/// The function is compiled with `#[target_feature(enable = "lsx")]`, so the
/// caller must make sure the running CPU supports LSX.
#[target_feature(enable = "lsx")]
27+
#[inline]
28+
pub(super) unsafe fn f16x4_to_f32x4_lsx(v: &[u16; 4]) -> [f32; 4] {
29+
// Zeroed because the copy below writes only four `u16`s (8 bytes);
// the rest of the vector must still be initialized before `assume_init`.
let mut vec = MaybeUninit::<m128i>::zeroed();
30+
// Copy the four input halves to the start of the vector.
ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
31+
// NOTE(review): presumably widens the low four f16 lanes to four f32 lanes;
// confirm against the LSX intrinsics reference.
let retval = lsx_vfcvtl_s_h(vec.assume_init());
32+
// Reinterpret the result vector as `[f32; 4]` (inferred from the return type).
*(&retval as *const m128).cast()
33+
}
34+
35+
/// Converts four `f32` values to four half-precision values, returned as raw
/// `u16` bits, with a single `lsx_vfcvt_h_s` call.
///
/// # Safety
///
/// The function is compiled with `#[target_feature(enable = "lsx")]`, so the
/// caller must make sure the running CPU supports LSX.
#[target_feature(enable = "lsx")]
35+
#[inline]
36+
pub(super) unsafe fn f32x4_to_f16x4_lsx(v: &[f32; 4]) -> [u16; 4] {
37+
// Left uninitialized here: the copy below writes four `f32`s (16 bytes).
// NOTE(review): this relies on `m128` being exactly 16 bytes so the copy
// initializes the whole vector before `assume_init`; confirm.
let mut vec = MaybeUninit::<m128>::uninit();
38+
// Copy the four input floats to the start of the vector.
ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
39+
// The intrinsic takes two source vectors; the same one is passed for both.
// NOTE(review): presumably the first four f16 result lanes then correspond
// to `v[0..4]`; confirm the operand/lane order in the LSX reference.
let retval = lsx_vfcvt_h_s(vec.assume_init(), vec.assume_init());
40+
// Reinterpret the start of the result vector as `[u16; 4]`
// (inferred from the return type); the remaining lanes are ignored.
*(&retval as *const m128i).cast()
41+
}
43+
44+
/// Converts four half-precision values, given as raw `u16` bits, to four
/// `f64` values.
///
/// The halves are first converted to `f32` by `f16x4_to_f32x4_lsx`, then each
/// element is widened to `f64` with a plain `as` cast.
///
/// # Safety
///
/// The function is compiled with `#[target_feature(enable = "lsx")]`, so the
/// caller must make sure the running CPU supports LSX.
#[target_feature(enable = "lsx")]
45+
#[inline]
46+
pub(super) unsafe fn f16x4_to_f64x4_lsx(v: &[u16; 4]) -> [f64; 4] {
47+
// Hardware-assisted f16 -> f32 step.
let array = f16x4_to_f32x4_lsx(v);
48+
// Let compiler vectorize this regular cast for now.
49+
[
50+
array[0] as f64,
51+
array[1] as f64,
52+
array[2] as f64,
53+
array[3] as f64,
54+
]
55+
}
56+
57+
/// Converts four `f64` values to four half-precision values, returned as raw
/// `u16` bits.
///
/// Each element is first narrowed to `f32` with a plain `as` cast, and the
/// resulting array is converted by `f32x4_to_f16x4_lsx`.
///
/// NOTE(review): the value is rounded twice (f64 -> f32 -> f16), which may
/// differ from a single direct f64 -> f16 rounding in rare cases; confirm
/// this is acceptable for callers.
///
/// # Safety
///
/// The function is compiled with `#[target_feature(enable = "lsx")]`, so the
/// caller must make sure the running CPU supports LSX.
#[target_feature(enable = "lsx")]
58+
#[inline]
59+
pub(super) unsafe fn f64x4_to_f16x4_lsx(v: &[f64; 4]) -> [u16; 4] {
60+
// Let compiler vectorize this regular cast for now.
61+
let v = [v[0] as f32, v[1] as f32, v[2] as f32, v[3] as f32];
62+
// Hardware-assisted f32 -> f16 step on the narrowed values.
f32x4_to_f16x4_lsx(&v)
63+
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
//! | ------------ | ------------------ | ----- |
5050
//! | `x86`/`x86_64` | `f16c` | This supports conversion to/from [`struct@f16`] only (including vector SIMD) and does not support any [`struct@bf16`] or arithmetic operations. |
5151
//! | `aarch64` | `fp16` | This supports all operations on [`struct@f16`] only. |
52+
//! | `loongarch64` | `lsx` | This supports conversion to/from [`struct@f16`] only (including vector SIMD) and does not support any [`struct@bf16`] or arithmetic operations. |
5253
//!
5354
//! # Cargo Features
5455
//!

0 commit comments

Comments
 (0)