Skip to content

Commit ab6b50b

Browse files
committed
cg_llvm: sve_cast intrinsic
Abstract over the existing `simd_cast` intrinsic to implement a new `sve_cast` intrinsic - this is better than allowing scalable vectors to be used with all of the generic `simd_*` intrinsics.
1 parent cecef9a commit ab6b50b

5 files changed

Lines changed: 203 additions & 89 deletions

File tree

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 114 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,27 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
581581
self.pointercast(val, self.type_ptr())
582582
}
583583

584+
sym::sve_cast => {
585+
let Some((in_cnt, in_elem, in_num_vecs)) =
586+
args[0].layout.ty.scalable_vector_parts(self.cx.tcx)
587+
else {
588+
bug!("input parameter to `sve_cast` was not scalable vector");
589+
};
590+
let out_layout = self.layout_of(fn_args.type_at(1));
591+
let Some((out_cnt, out_elem, out_num_vecs)) =
592+
out_layout.ty.scalable_vector_parts(self.cx.tcx)
593+
else {
594+
bug!("output parameter to `sve_cast` was not scalable vector");
595+
};
596+
assert_eq!(in_cnt, out_cnt);
597+
assert_eq!(in_num_vecs, out_num_vecs);
598+
let out_llty = self.backend_type(out_layout);
599+
match simd_cast(self, sym::simd_cast, args, out_llty, in_elem, out_elem) {
600+
Some(val) => val,
601+
_ => bug!("could not cast scalable vectors"),
602+
}
603+
}
604+
584605
sym::sve_tuple_create2 => {
585606
assert_matches!(
586607
self.layout_of(fn_args.type_at(0)).backend_repr,
@@ -2738,96 +2759,17 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
27382759
out_len
27392760
}
27402761
);
2741-
// casting cares about nominal type, not just structural type
2742-
if in_elem == out_elem {
2743-
return Ok(args[0].immediate());
2744-
}
2745-
2746-
#[derive(Copy, Clone)]
2747-
enum Sign {
2748-
Unsigned,
2749-
Signed,
2750-
}
2751-
use Sign::*;
2752-
2753-
enum Style {
2754-
Float,
2755-
Int(Sign),
2756-
Unsupported,
2757-
}
2758-
2759-
let (in_style, in_width) = match in_elem.kind() {
2760-
// vectors of pointer-sized integers should've been
2761-
// disallowed before here, so this unwrap is safe.
2762-
ty::Int(i) => (
2763-
Style::Int(Signed),
2764-
i.normalize(bx.tcx().sess.target.pointer_width).bit_width().unwrap(),
2765-
),
2766-
ty::Uint(u) => (
2767-
Style::Int(Unsigned),
2768-
u.normalize(bx.tcx().sess.target.pointer_width).bit_width().unwrap(),
2769-
),
2770-
ty::Float(f) => (Style::Float, f.bit_width()),
2771-
_ => (Style::Unsupported, 0),
2772-
};
2773-
let (out_style, out_width) = match out_elem.kind() {
2774-
ty::Int(i) => (
2775-
Style::Int(Signed),
2776-
i.normalize(bx.tcx().sess.target.pointer_width).bit_width().unwrap(),
2777-
),
2778-
ty::Uint(u) => (
2779-
Style::Int(Unsigned),
2780-
u.normalize(bx.tcx().sess.target.pointer_width).bit_width().unwrap(),
2781-
),
2782-
ty::Float(f) => (Style::Float, f.bit_width()),
2783-
_ => (Style::Unsupported, 0),
2784-
};
2785-
2786-
match (in_style, out_style) {
2787-
(Style::Int(sign), Style::Int(_)) => {
2788-
return Ok(match in_width.cmp(&out_width) {
2789-
Ordering::Greater => bx.trunc(args[0].immediate(), llret_ty),
2790-
Ordering::Equal => args[0].immediate(),
2791-
Ordering::Less => match sign {
2792-
Sign::Signed => bx.sext(args[0].immediate(), llret_ty),
2793-
Sign::Unsigned => bx.zext(args[0].immediate(), llret_ty),
2794-
},
2795-
});
2796-
}
2797-
(Style::Int(Sign::Signed), Style::Float) => {
2798-
return Ok(bx.sitofp(args[0].immediate(), llret_ty));
2799-
}
2800-
(Style::Int(Sign::Unsigned), Style::Float) => {
2801-
return Ok(bx.uitofp(args[0].immediate(), llret_ty));
2802-
}
2803-
(Style::Float, Style::Int(sign)) => {
2804-
return Ok(match (sign, name == sym::simd_as) {
2805-
(Sign::Unsigned, false) => bx.fptoui(args[0].immediate(), llret_ty),
2806-
(Sign::Signed, false) => bx.fptosi(args[0].immediate(), llret_ty),
2807-
(_, true) => bx.cast_float_to_int(
2808-
matches!(sign, Sign::Signed),
2809-
args[0].immediate(),
2810-
llret_ty,
2811-
),
2812-
});
2813-
}
2814-
(Style::Float, Style::Float) => {
2815-
return Ok(match in_width.cmp(&out_width) {
2816-
Ordering::Greater => bx.fptrunc(args[0].immediate(), llret_ty),
2817-
Ordering::Equal => args[0].immediate(),
2818-
Ordering::Less => bx.fpext(args[0].immediate(), llret_ty),
2819-
});
2820-
}
2821-
_ => { /* Unsupported. Fallthrough. */ }
2762+
match simd_cast(bx, name, args, llret_ty, in_elem, out_elem) {
2763+
Some(val) => return Ok(val),
2764+
None => return_error!(InvalidMonomorphization::UnsupportedCast {
2765+
span,
2766+
name,
2767+
in_ty,
2768+
in_elem,
2769+
ret_ty,
2770+
out_elem
2771+
}),
28222772
}
2823-
return_error!(InvalidMonomorphization::UnsupportedCast {
2824-
span,
2825-
name,
2826-
in_ty,
2827-
in_elem,
2828-
ret_ty,
2829-
out_elem
2830-
});
28312773
}
28322774
macro_rules! arith_binary {
28332775
($($name: ident: $($($p: ident),* => $call: ident),*;)*) => {
@@ -3001,3 +2943,86 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
30012943

30022944
span_bug!(span, "unknown SIMD intrinsic");
30032945
}
2946+
2947+
/// Implementation of `core::intrinsics::simd_cast`, re-used by `core::intrinsics::simd::scalable::sve_cast`.
2948+
fn simd_cast<'ll, 'tcx>(
2949+
bx: &mut Builder<'_, 'll, 'tcx>,
2950+
name: Symbol,
2951+
args: &[OperandRef<'tcx, &'ll Value>],
2952+
llret_ty: &'ll Type,
2953+
in_elem: Ty<'tcx>,
2954+
out_elem: Ty<'tcx>,
2955+
) -> Option<&'ll Value> {
2956+
// Casting cares about nominal type, not just structural type
2957+
if in_elem == out_elem {
2958+
return Some(args[0].immediate());
2959+
}
2960+
2961+
#[derive(Copy, Clone)]
2962+
enum Sign {
2963+
Unsigned,
2964+
Signed,
2965+
}
2966+
use Sign::*;
2967+
2968+
enum Style {
2969+
Float,
2970+
Int(Sign),
2971+
Unsupported,
2972+
}
2973+
2974+
let (in_style, in_width) = match in_elem.kind() {
2975+
// vectors of pointer-sized integers should've been
2976+
// disallowed before here, so this unwrap is safe.
2977+
ty::Int(i) => (
2978+
Style::Int(Signed),
2979+
i.normalize(bx.tcx().sess.target.pointer_width).bit_width().unwrap(),
2980+
),
2981+
ty::Uint(u) => (
2982+
Style::Int(Unsigned),
2983+
u.normalize(bx.tcx().sess.target.pointer_width).bit_width().unwrap(),
2984+
),
2985+
ty::Float(f) => (Style::Float, f.bit_width()),
2986+
_ => (Style::Unsupported, 0),
2987+
};
2988+
let (out_style, out_width) = match out_elem.kind() {
2989+
ty::Int(i) => (
2990+
Style::Int(Signed),
2991+
i.normalize(bx.tcx().sess.target.pointer_width).bit_width().unwrap(),
2992+
),
2993+
ty::Uint(u) => (
2994+
Style::Int(Unsigned),
2995+
u.normalize(bx.tcx().sess.target.pointer_width).bit_width().unwrap(),
2996+
),
2997+
ty::Float(f) => (Style::Float, f.bit_width()),
2998+
_ => (Style::Unsupported, 0),
2999+
};
3000+
3001+
match (in_style, out_style) {
3002+
(Style::Int(sign), Style::Int(_)) => Some(match in_width.cmp(&out_width) {
3003+
Ordering::Greater => bx.trunc(args[0].immediate(), llret_ty),
3004+
Ordering::Equal => args[0].immediate(),
3005+
Ordering::Less => match sign {
3006+
Sign::Signed => bx.sext(args[0].immediate(), llret_ty),
3007+
Sign::Unsigned => bx.zext(args[0].immediate(), llret_ty),
3008+
},
3009+
}),
3010+
(Style::Int(Sign::Signed), Style::Float) => Some(bx.sitofp(args[0].immediate(), llret_ty)),
3011+
(Style::Int(Sign::Unsigned), Style::Float) => {
3012+
Some(bx.uitofp(args[0].immediate(), llret_ty))
3013+
}
3014+
(Style::Float, Style::Int(sign)) => Some(match (sign, name == sym::simd_as) {
3015+
(Sign::Unsigned, false) => bx.fptoui(args[0].immediate(), llret_ty),
3016+
(Sign::Signed, false) => bx.fptosi(args[0].immediate(), llret_ty),
3017+
(_, true) => {
3018+
bx.cast_float_to_int(matches!(sign, Sign::Signed), args[0].immediate(), llret_ty)
3019+
}
3020+
}),
3021+
(Style::Float, Style::Float) => Some(match in_width.cmp(&out_width) {
3022+
Ordering::Greater => bx.fptrunc(args[0].immediate(), llret_ty),
3023+
Ordering::Equal => args[0].immediate(),
3024+
Ordering::Less => bx.fpext(args[0].immediate(), llret_ty),
3025+
}),
3026+
_ => None,
3027+
}
3028+
}

compiler/rustc_hir_analysis/src/check/intrinsic.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,7 @@ pub(crate) fn check_intrinsic_type(
785785
sym::simd_shuffle => (3, 0, vec![param(0), param(0), param(1)], param(2)),
786786
sym::simd_shuffle_const_generic => (2, 1, vec![param(0), param(0)], param(1)),
787787

788+
sym::sve_cast => (2, 0, vec![param(0)], param(1)),
788789
sym::sve_tuple_create2 => (2, 0, vec![param(0), param(0)], param(1)),
789790
sym::sve_tuple_create3 => (2, 0, vec![param(0), param(0), param(0)], param(1)),
790791
sym::sve_tuple_create4 => (2, 0, vec![param(0), param(0), param(0), param(0)], param(1)),

compiler/rustc_span/src/symbol.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1984,6 +1984,7 @@ symbols! {
19841984
suggestion,
19851985
super_let,
19861986
supertrait_item_shadowing,
1987+
sve_cast,
19871988
sve_tuple_create2,
19881989
sve_tuple_create3,
19891990
sve_tuple_create4,

library/core/src/intrinsics/simd/scalable.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,29 @@
22
//!
33
//! In this module, a "vector" is any `#[rustc_scalable_vector]`-annotated type.
44
5+
/// Numerically casts a vector, elementwise.
6+
///
7+
/// `T` and `U` must be vectors of integers or floats, and must have the same length.
8+
///
9+
/// When casting floats to integers, the result is truncated. Out-of-bounds results lead to UB.
10+
/// When casting integers to floats, the result is rounded.
11+
/// Otherwise, truncates or extends the value, maintaining the sign for signed integers.
12+
///
13+
/// # Safety
14+
/// Casting from integer types is always safe.
15+
/// Casting between two float types is also always safe.
16+
///
17+
/// Casting floats to integers truncates, following the same rules as `to_int_unchecked`.
18+
/// Specifically, each element must:
19+
/// * Not be `NaN`
20+
/// * Not be infinite
21+
/// * Be representable in the return type, after truncating off its fractional part
22+
#[cfg(target_arch = "aarch64")]
23+
#[rustc_intrinsic]
24+
#[rustc_nounwind]
25+
#[target_feature(enable = "sve")]
26+
pub unsafe fn sve_cast<T, U>(x: T) -> U;
27+
528
/// Create a tuple of two vectors.
629
///
730
/// `SVecTup` must be a scalable vector tuple (`#[rustc_scalable_vector]`) and `SVec` must be a
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
//@ check-pass
2+
#![crate_type = "lib"]
3+
#![allow(incomplete_features, internal_features, improper_ctypes)]
4+
#![feature(abi_unadjusted, core_intrinsics, link_llvm_intrinsics, rustc_attrs)]
5+
6+
use std::intrinsics::simd::scalable::sve_cast;
7+
8+
#[derive(Copy, Clone)]
9+
#[rustc_scalable_vector(16)]
10+
#[allow(non_camel_case_types)]
11+
pub struct svbool_t(bool);
12+
13+
#[derive(Copy, Clone)]
14+
#[rustc_scalable_vector(2)]
15+
#[allow(non_camel_case_types)]
16+
pub struct svbool2_t(bool);
17+
18+
#[derive(Copy, Clone)]
19+
#[rustc_scalable_vector(2)]
20+
#[allow(non_camel_case_types)]
21+
pub struct svint64_t(i64);
22+
23+
#[derive(Copy, Clone)]
24+
#[rustc_scalable_vector(2)]
25+
#[allow(non_camel_case_types)]
26+
pub struct nxv2i16(i16);
27+
28+
pub trait SveInto<T>: Sized {
29+
unsafe fn sve_into(self) -> T;
30+
}
31+
32+
impl SveInto<svbool2_t> for svbool_t {
33+
#[target_feature(enable = "sve")]
34+
unsafe fn sve_into(self) -> svbool2_t {
35+
unsafe extern "C" {
36+
#[cfg_attr(
37+
target_arch = "aarch64",
38+
link_name = concat!("llvm.aarch64.sve.convert.from.svbool.nxv2i1")
39+
)]
40+
fn convert_from_svbool(b: svbool_t) -> svbool2_t;
41+
}
42+
unsafe { convert_from_svbool(self) }
43+
}
44+
}
45+
46+
#[target_feature(enable = "sve")]
47+
pub unsafe fn svld1sh_gather_s64offset_s64(
48+
pg: svbool_t,
49+
base: *const i16,
50+
offsets: svint64_t,
51+
) -> svint64_t {
52+
unsafe extern "unadjusted" {
53+
#[cfg_attr(
54+
target_arch = "aarch64",
55+
link_name = "llvm.aarch64.sve.ld1.gather.nxv2i16"
56+
)]
57+
fn _svld1sh_gather_s64offset_s64(
58+
pg: svbool2_t,
59+
base: *const i16,
60+
offsets: svint64_t,
61+
) -> nxv2i16;
62+
}
63+
sve_cast(_svld1sh_gather_s64offset_s64(pg.sve_into(), base, offsets))
64+
}

0 commit comments

Comments
 (0)