Skip to content

Commit 957320c

Browse files
committed
cg_llvm: sve_cast intrinsic
Abstract over the existing `simd_cast` intrinsic to implement a new `sve_cast` intrinsic - this is better than allowing scalable vectors to be used with all of the generic `simd_*` intrinsics.
1 parent a24ee03 commit 957320c

6 files changed

Lines changed: 204 additions & 90 deletions

File tree

compiler/rustc_codegen_llvm/src/debuginfo/metadata.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::fmt::{self, Write};
33
use std::hash::{Hash, Hasher};
44
use std::path::PathBuf;
55
use std::sync::Arc;
6-
use std::{iter, ptr};
6+
use std::{assert_matches, iter, ptr};
77

88
use libc::{c_longlong, c_uint};
99
use rustc_abi::{Align, Layout, NumScalableVectors, Size};

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 114 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,27 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
606606
self.pointercast(val, self.type_ptr())
607607
}
608608

609+
sym::sve_cast => {
610+
let Some((in_cnt, in_elem, in_num_vecs)) =
611+
args[0].layout.ty.scalable_vector_parts(self.cx.tcx)
612+
else {
613+
bug!("input parameter to `sve_cast` was not scalable vector");
614+
};
615+
let out_layout = self.layout_of(fn_args.type_at(1));
616+
let Some((out_cnt, out_elem, out_num_vecs)) =
617+
out_layout.ty.scalable_vector_parts(self.cx.tcx)
618+
else {
619+
bug!("output parameter to `sve_cast` was not scalable vector");
620+
};
621+
assert_eq!(in_cnt, out_cnt);
622+
assert_eq!(in_num_vecs, out_num_vecs);
623+
let out_llty = self.backend_type(out_layout);
624+
match simd_cast(self, sym::simd_cast, args, out_llty, in_elem, out_elem) {
625+
Some(val) => val,
626+
_ => bug!("could not cast scalable vectors"),
627+
}
628+
}
629+
609630
sym::sve_tuple_create2 => {
610631
assert_matches!(
611632
self.layout_of(fn_args.type_at(0)).backend_repr,
@@ -2772,96 +2793,17 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
27722793
out_len
27732794
}
27742795
);
2775-
// casting cares about nominal type, not just structural type
2776-
if in_elem == out_elem {
2777-
return Ok(args[0].immediate());
2778-
}
2779-
2780-
#[derive(Copy, Clone)]
2781-
enum Sign {
2782-
Unsigned,
2783-
Signed,
2784-
}
2785-
use Sign::*;
2786-
2787-
enum Style {
2788-
Float,
2789-
Int(Sign),
2790-
Unsupported,
2791-
}
2792-
2793-
let (in_style, in_width) = match in_elem.kind() {
2794-
// vectors of pointer-sized integers should've been
2795-
// disallowed before here, so this unwrap is safe.
2796-
ty::Int(i) => (
2797-
Style::Int(Signed),
2798-
i.normalize(bx.tcx().sess.target.pointer_width).bit_width().unwrap(),
2799-
),
2800-
ty::Uint(u) => (
2801-
Style::Int(Unsigned),
2802-
u.normalize(bx.tcx().sess.target.pointer_width).bit_width().unwrap(),
2803-
),
2804-
ty::Float(f) => (Style::Float, f.bit_width()),
2805-
_ => (Style::Unsupported, 0),
2806-
};
2807-
let (out_style, out_width) = match out_elem.kind() {
2808-
ty::Int(i) => (
2809-
Style::Int(Signed),
2810-
i.normalize(bx.tcx().sess.target.pointer_width).bit_width().unwrap(),
2811-
),
2812-
ty::Uint(u) => (
2813-
Style::Int(Unsigned),
2814-
u.normalize(bx.tcx().sess.target.pointer_width).bit_width().unwrap(),
2815-
),
2816-
ty::Float(f) => (Style::Float, f.bit_width()),
2817-
_ => (Style::Unsupported, 0),
2818-
};
2819-
2820-
match (in_style, out_style) {
2821-
(Style::Int(sign), Style::Int(_)) => {
2822-
return Ok(match in_width.cmp(&out_width) {
2823-
Ordering::Greater => bx.trunc(args[0].immediate(), llret_ty),
2824-
Ordering::Equal => args[0].immediate(),
2825-
Ordering::Less => match sign {
2826-
Sign::Signed => bx.sext(args[0].immediate(), llret_ty),
2827-
Sign::Unsigned => bx.zext(args[0].immediate(), llret_ty),
2828-
},
2829-
});
2830-
}
2831-
(Style::Int(Sign::Signed), Style::Float) => {
2832-
return Ok(bx.sitofp(args[0].immediate(), llret_ty));
2833-
}
2834-
(Style::Int(Sign::Unsigned), Style::Float) => {
2835-
return Ok(bx.uitofp(args[0].immediate(), llret_ty));
2836-
}
2837-
(Style::Float, Style::Int(sign)) => {
2838-
return Ok(match (sign, name == sym::simd_as) {
2839-
(Sign::Unsigned, false) => bx.fptoui(args[0].immediate(), llret_ty),
2840-
(Sign::Signed, false) => bx.fptosi(args[0].immediate(), llret_ty),
2841-
(_, true) => bx.cast_float_to_int(
2842-
matches!(sign, Sign::Signed),
2843-
args[0].immediate(),
2844-
llret_ty,
2845-
),
2846-
});
2847-
}
2848-
(Style::Float, Style::Float) => {
2849-
return Ok(match in_width.cmp(&out_width) {
2850-
Ordering::Greater => bx.fptrunc(args[0].immediate(), llret_ty),
2851-
Ordering::Equal => args[0].immediate(),
2852-
Ordering::Less => bx.fpext(args[0].immediate(), llret_ty),
2853-
});
2854-
}
2855-
_ => { /* Unsupported. Fallthrough. */ }
2796+
match simd_cast(bx, name, args, llret_ty, in_elem, out_elem) {
2797+
Some(val) => return Ok(val),
2798+
None => return_error!(InvalidMonomorphization::UnsupportedCast {
2799+
span,
2800+
name,
2801+
in_ty,
2802+
in_elem,
2803+
ret_ty,
2804+
out_elem
2805+
}),
28562806
}
2857-
return_error!(InvalidMonomorphization::UnsupportedCast {
2858-
span,
2859-
name,
2860-
in_ty,
2861-
in_elem,
2862-
ret_ty,
2863-
out_elem
2864-
});
28652807
}
28662808
macro_rules! arith_binary {
28672809
($($name: ident: $($($p: ident),* => $call: ident),*;)*) => {
@@ -3035,3 +2977,86 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
30352977

30362978
span_bug!(span, "unknown SIMD intrinsic");
30372979
}
2980+
2981+
/// Implementation of `core::intrinsics::simd_cast`, re-used by `core::scalable::sve_cast`.
2982+
fn simd_cast<'ll, 'tcx>(
2983+
bx: &mut Builder<'_, 'll, 'tcx>,
2984+
name: Symbol,
2985+
args: &[OperandRef<'tcx, &'ll Value>],
2986+
llret_ty: &'ll Type,
2987+
in_elem: Ty<'tcx>,
2988+
out_elem: Ty<'tcx>,
2989+
) -> Option<&'ll Value> {
2990+
// Casting cares about nominal type, not just structural type
2991+
if in_elem == out_elem {
2992+
return Some(args[0].immediate());
2993+
}
2994+
2995+
#[derive(Copy, Clone)]
2996+
enum Sign {
2997+
Unsigned,
2998+
Signed,
2999+
}
3000+
use Sign::*;
3001+
3002+
enum Style {
3003+
Float,
3004+
Int(Sign),
3005+
Unsupported,
3006+
}
3007+
3008+
let (in_style, in_width) = match in_elem.kind() {
3009+
// vectors of pointer-sized integers should've been
3010+
// disallowed before here, so this unwrap is safe.
3011+
ty::Int(i) => (
3012+
Style::Int(Signed),
3013+
i.normalize(bx.tcx().sess.target.pointer_width).bit_width().unwrap(),
3014+
),
3015+
ty::Uint(u) => (
3016+
Style::Int(Unsigned),
3017+
u.normalize(bx.tcx().sess.target.pointer_width).bit_width().unwrap(),
3018+
),
3019+
ty::Float(f) => (Style::Float, f.bit_width()),
3020+
_ => (Style::Unsupported, 0),
3021+
};
3022+
let (out_style, out_width) = match out_elem.kind() {
3023+
ty::Int(i) => (
3024+
Style::Int(Signed),
3025+
i.normalize(bx.tcx().sess.target.pointer_width).bit_width().unwrap(),
3026+
),
3027+
ty::Uint(u) => (
3028+
Style::Int(Unsigned),
3029+
u.normalize(bx.tcx().sess.target.pointer_width).bit_width().unwrap(),
3030+
),
3031+
ty::Float(f) => (Style::Float, f.bit_width()),
3032+
_ => (Style::Unsupported, 0),
3033+
};
3034+
3035+
match (in_style, out_style) {
3036+
(Style::Int(sign), Style::Int(_)) => Some(match in_width.cmp(&out_width) {
3037+
Ordering::Greater => bx.trunc(args[0].immediate(), llret_ty),
3038+
Ordering::Equal => args[0].immediate(),
3039+
Ordering::Less => match sign {
3040+
Sign::Signed => bx.sext(args[0].immediate(), llret_ty),
3041+
Sign::Unsigned => bx.zext(args[0].immediate(), llret_ty),
3042+
},
3043+
}),
3044+
(Style::Int(Sign::Signed), Style::Float) => Some(bx.sitofp(args[0].immediate(), llret_ty)),
3045+
(Style::Int(Sign::Unsigned), Style::Float) => {
3046+
Some(bx.uitofp(args[0].immediate(), llret_ty))
3047+
}
3048+
(Style::Float, Style::Int(sign)) => Some(match (sign, name == sym::simd_as) {
3049+
(Sign::Unsigned, false) => bx.fptoui(args[0].immediate(), llret_ty),
3050+
(Sign::Signed, false) => bx.fptosi(args[0].immediate(), llret_ty),
3051+
(_, true) => {
3052+
bx.cast_float_to_int(matches!(sign, Sign::Signed), args[0].immediate(), llret_ty)
3053+
}
3054+
}),
3055+
(Style::Float, Style::Float) => Some(match in_width.cmp(&out_width) {
3056+
Ordering::Greater => bx.fptrunc(args[0].immediate(), llret_ty),
3057+
Ordering::Equal => args[0].immediate(),
3058+
Ordering::Less => bx.fpext(args[0].immediate(), llret_ty),
3059+
}),
3060+
_ => None,
3061+
}
3062+
}

compiler/rustc_hir_analysis/src/check/intrinsic.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -783,6 +783,7 @@ pub(crate) fn check_intrinsic_type(
783783
sym::simd_shuffle => (3, 0, vec![param(0), param(0), param(1)], param(2)),
784784
sym::simd_shuffle_const_generic => (2, 1, vec![param(0), param(0)], param(1)),
785785

786+
sym::sve_cast => (2, 0, vec![param(0)], param(1)),
786787
sym::sve_tuple_create2 => (2, 0, vec![param(0), param(0)], param(1)),
787788
sym::sve_tuple_create3 => (2, 0, vec![param(0), param(0), param(0)], param(1)),
788789
sym::sve_tuple_create4 => (2, 0, vec![param(0), param(0), param(0), param(0)], param(1)),

compiler/rustc_span/src/symbol.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1979,6 +1979,7 @@ symbols! {
19791979
suggestion,
19801980
super_let,
19811981
supertrait_item_shadowing,
1982+
sve_cast,
19821983
sve_tuple_create2,
19831984
sve_tuple_create3,
19841985
sve_tuple_create4,

library/core/src/intrinsics/simd/scalable.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,28 @@
22
//!
33
//! In this module, a "vector" is any `#[rustc_scalable_vector]`-annotated type.
44
5+
/// Numerically casts a vector, elementwise.
///
/// `T` and `U` must be vectors of integers or floats, and must have the same length.
///
/// When casting floats to integers, the result is truncated. Out-of-bounds results lead to UB.
/// When casting integers to floats, the result is rounded.
/// Otherwise, truncates or extends the value, maintaining the sign for signed integers.
///
/// # Safety
/// Casting from integer types is always safe.
/// Casting between two float types is also always safe.
///
/// Casting floats to integers truncates, following the same rules as `to_int_unchecked`.
/// Specifically, each element must:
/// * Not be `NaN`
/// * Not be infinite
/// * Be representable in the return type, after truncating off its fractional part
#[cfg(target_arch = "aarch64")]
#[rustc_intrinsic]
#[rustc_nounwind]
pub unsafe fn sve_cast<T, U>(x: T) -> U;
26+
527
/// Create a tuple of two vectors.
628
///
729
/// `SVecTup` must be a scalable vector tuple (`#[rustc_scalable_vector]`) and `SVec` must be a
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
//@ check-pass
2+
//@ only-aarch64
3+
#![crate_type = "lib"]
4+
#![allow(incomplete_features, internal_features, improper_ctypes)]
5+
#![feature(abi_unadjusted, core_intrinsics, link_llvm_intrinsics, rustc_attrs)]
6+
7+
use std::intrinsics::simd::scalable::sve_cast;
8+
9+
#[derive(Copy, Clone)]
10+
#[rustc_scalable_vector(16)]
11+
#[allow(non_camel_case_types)]
12+
pub struct svbool_t(bool);
13+
14+
#[derive(Copy, Clone)]
15+
#[rustc_scalable_vector(2)]
16+
#[allow(non_camel_case_types)]
17+
pub struct svbool2_t(bool);
18+
19+
#[derive(Copy, Clone)]
20+
#[rustc_scalable_vector(2)]
21+
#[allow(non_camel_case_types)]
22+
pub struct svint64_t(i64);
23+
24+
#[derive(Copy, Clone)]
25+
#[rustc_scalable_vector(2)]
26+
#[allow(non_camel_case_types)]
27+
pub struct nxv2i16(i16);
28+
29+
/// Conversion between scalable vector types.
pub trait SveInto<T>: Sized {
    /// Converts `self` into `T`.
    ///
    /// # Safety
    /// The implementation in this test calls an SVE LLVM intrinsic and is
    /// gated on `#[target_feature(enable = "sve")]`; callers must ensure SVE
    /// is available.
    unsafe fn sve_into(self) -> T;
}
32+
33+
// Convert the full 16-element predicate into a 2-element predicate using the
// LLVM SVE predicate-conversion intrinsic.
impl SveInto<svbool2_t> for svbool_t {
    #[target_feature(enable = "sve")]
    unsafe fn sve_into(self) -> svbool2_t {
        unsafe extern "C" {
            #[cfg_attr(
                target_arch = "aarch64",
                link_name = concat!("llvm.aarch64.sve.convert.from.svbool.nxv2i1")
            )]
            fn convert_from_svbool(b: svbool_t) -> svbool2_t;
        }
        // SAFETY: caller upholds the `sve` target-feature requirement of this
        // method; the intrinsic's signature matches the declaration above.
        unsafe { convert_from_svbool(self) }
    }
}
46+
47+
/// Predicated gather-load of 16-bit elements from `base` + `offsets`, widened
/// to 64-bit lanes.
///
/// Exercises `sve_cast`: the raw LLVM gather intrinsic returns `nxv2i16`,
/// which is cast elementwise to the `svint64_t` return type (sign-extending,
/// since both element types are signed integers).
#[target_feature(enable = "sve")]
pub unsafe fn svld1sh_gather_s64offset_s64(
    pg: svbool_t,
    base: *const i16,
    offsets: svint64_t,
) -> svint64_t {
    unsafe extern "unadjusted" {
        #[cfg_attr(
            target_arch = "aarch64",
            link_name = "llvm.aarch64.sve.ld1.gather.nxv2i16"
        )]
        fn _svld1sh_gather_s64offset_s64(
            pg: svbool2_t,
            base: *const i16,
            offsets: svint64_t,
        ) -> nxv2i16;
    }
    // Narrow the predicate to 2 lanes for the intrinsic call, then widen the
    // i16 gather result to i64 lanes via `sve_cast`.
    sve_cast(_svld1sh_gather_s64offset_s64(pg.sve_into(), base, offsets))
}

0 commit comments

Comments
 (0)