Skip to content

Commit 43fb39f

Browse files
committed
Add avx512 pack* family of instructions
1 parent decd472 commit 43fb39f

2 files changed

Lines changed: 161 additions & 3 deletions

File tree

src/tools/miri/src/shims/x86/avx512.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use rustc_middle::ty::Ty;
33
use rustc_span::Symbol;
44
use rustc_target::callconv::FnAbi;
55

6-
use super::{permute, pmaddbw, pmaddwd, psadbw, pshufb};
6+
use super::{packssdw, packsswb, packusdw, packuswb, permute, pmaddbw, pmaddwd, psadbw, pshufb};
77
use crate::*;
88

99
impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
@@ -130,6 +130,38 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
130130

131131
vpdpbusd(this, src, a, b, dest)?;
132132
}
133+
// Used to implement the _mm512_packs_epi16 function
134+
"packsswb.512" => {
135+
this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
136+
137+
let [a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
138+
139+
packsswb(this, a, b, dest)?;
140+
}
141+
// Used to implement the _mm512_packus_epi16 function
142+
"packuswb.512" => {
143+
this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
144+
145+
let [a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
146+
147+
packuswb(this, a, b, dest)?;
148+
}
149+
// Used to implement the _mm512_packs_epi32 function
150+
"packssdw.512" => {
151+
this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
152+
153+
let [a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
154+
155+
packssdw(this, a, b, dest)?;
156+
}
157+
// Used to implement the _mm512_packus_epi32 function
158+
"packusdw.512" => {
159+
this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
160+
161+
let [a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
162+
163+
packusdw(this, a, b, dest)?;
164+
}
133165
_ => return interp_ok(EmulateItemResult::NotSupported),
134166
}
135167
interp_ok(EmulateItemResult::NeedsReturn)

src/tools/miri/tests/pass/shims/x86/intrinsics-x86-avx512.rs

Lines changed: 128 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// We're testing x86 target specific features
22
//@only-target: x86_64 i686
3-
//@compile-flags: -C target-feature=+avx512f,+avx512vl,+avx512bitalg,+avx512vpopcntdq,+avx512vnni
3+
//@compile-flags: -C target-feature=+avx512f,+avx512vl,+avx512bw,+avx512bitalg,+avx512vpopcntdq,+avx512vnni
44

55
#[cfg(target_arch = "x86")]
66
use std::arch::x86::*;
@@ -11,12 +11,14 @@ use std::mem::transmute;
1111
fn main() {
1212
assert!(is_x86_feature_detected!("avx512f"));
1313
assert!(is_x86_feature_detected!("avx512vl"));
14+
assert!(is_x86_feature_detected!("avx512bw"));
1415
assert!(is_x86_feature_detected!("avx512bitalg"));
1516
assert!(is_x86_feature_detected!("avx512vpopcntdq"));
1617
assert!(is_x86_feature_detected!("avx512vnni"));
1718

1819
unsafe {
1920
test_avx512();
21+
test_avx512bw();
2022
test_avx512bitalg();
2123
test_avx512vpopcntdq();
2224
test_avx512ternarylogic();
@@ -579,9 +581,133 @@ unsafe fn test_avx512vnni() {
579581
test_mm512_dpbusd_epi32();
580582
}
581583

584+
#[target_feature(enable = "avx512bw")]
585+
unsafe fn test_avx512bw() {
586+
#[target_feature(enable = "avx512bw")]
587+
unsafe fn test_mm512_packs_epi16() {
588+
let a = _mm512_set1_epi16(120);
589+
590+
// Because `packs` instructions do signed saturation, we expect
591+
// that any value over `i8::MAX` will be saturated to `i8::MAX`, and any value
592+
// less than `i8::MIN` will also be saturated to `i8::MIN`.
593+
let b = _mm512_set_epi16(
594+
200, 200, 200, 200, 200, 200, 200, 200, -200, -200, -200, -200, -200, -200, -200, -200,
595+
200, 200, 200, 200, 200, 200, 200, 200, -200, -200, -200, -200, -200, -200, -200, -200,
596+
);
597+
598+
// The pack* family of instructions in x86 operate in blocks
599+
// of 128-bit lanes, meaning the first 128-bit lane in `a` is converted and written
600+
// then the first 128-bit lane of `b`, followed by the second 128-bit lane in `a`, etc...
601+
// Because we are going from 16 bits to 8 bits, each 128-bit block becomes 64 bits in
602+
// the output register.
603+
// This leaves us with alternating groups of 8x 8-bit values from `a` and `b` in the final register.
604+
#[rustfmt::skip]
605+
const DST: [i8; 64] = [
606+
120, 120, 120, 120, 120, 120, 120, 120,
607+
i8::MIN, i8::MIN, i8::MIN, i8::MIN, i8::MIN, i8::MIN, i8::MIN, i8::MIN,
608+
120, 120, 120, 120, 120, 120, 120, 120,
609+
i8::MAX, i8::MAX, i8::MAX, i8::MAX, i8::MAX, i8::MAX, i8::MAX, i8::MAX,
610+
120, 120, 120, 120, 120, 120, 120, 120,
611+
i8::MIN, i8::MIN, i8::MIN, i8::MIN, i8::MIN, i8::MIN, i8::MIN, i8::MIN,
612+
120, 120, 120, 120, 120, 120, 120, 120,
613+
i8::MAX, i8::MAX, i8::MAX, i8::MAX, i8::MAX, i8::MAX, i8::MAX, i8::MAX,
614+
];
615+
let dst = _mm512_loadu_si512(DST.as_ptr().cast::<__m512i>());
616+
assert_eq_m512i(_mm512_packs_epi16(a, b), dst);
617+
}
618+
test_mm512_packs_epi16();
619+
620+
#[target_feature(enable = "avx512bw")]
621+
unsafe fn test_mm512_packus_epi16() {
622+
let a = _mm512_set1_epi16(120);
623+
624+
// Because `packus` instructions do unsigned saturation, we expect
625+
// that any value over `u8::MAX` will be saturated to `u8::MAX`, and any value
626+
// less than `u8::MIN` (i.e. any negative input) will be saturated to `u8::MIN`.
627+
let b = _mm512_set_epi16(
628+
300, 300, 300, 300, 300, 300, 300, 300, -200, -200, -200, -200, -200, -200, -200, -200,
629+
300, 300, 300, 300, 300, 300, 300, 300, -200, -200, -200, -200, -200, -200, -200, -200,
630+
);
631+
632+
// See `test_mm512_packs_epi16` for an explanation of the output structure.
633+
#[rustfmt::skip]
634+
const DST: [u8; 64] = [
635+
120, 120, 120, 120, 120, 120, 120, 120,
636+
u8::MIN, u8::MIN, u8::MIN, u8::MIN, u8::MIN, u8::MIN, u8::MIN, u8::MIN,
637+
120, 120, 120, 120, 120, 120, 120, 120,
638+
u8::MAX, u8::MAX, u8::MAX, u8::MAX, u8::MAX, u8::MAX, u8::MAX, u8::MAX,
639+
120, 120, 120, 120, 120, 120, 120, 120,
640+
u8::MIN, u8::MIN, u8::MIN, u8::MIN, u8::MIN, u8::MIN, u8::MIN, u8::MIN,
641+
120, 120, 120, 120, 120, 120, 120, 120,
642+
u8::MAX, u8::MAX, u8::MAX, u8::MAX, u8::MAX, u8::MAX, u8::MAX, u8::MAX,
643+
];
644+
let dst = _mm512_loadu_si512(DST.as_ptr().cast::<__m512i>());
645+
assert_eq_m512i(_mm512_packus_epi16(a, b), dst);
646+
}
647+
test_mm512_packus_epi16();
648+
649+
#[target_feature(enable = "avx512bw")]
650+
unsafe fn test_mm512_packs_epi32() {
651+
let a = _mm512_set1_epi32(8_000);
652+
653+
// Because `packs` instructions do signed saturation, we expect
654+
// that any value over `i16::MAX` will be saturated to `i16::MAX`, and any value
655+
// less than `i16::MIN` will also be saturated to `i16::MIN`.
656+
let b = _mm512_set_epi32(
657+
50_000, 50_000, 50_000, 50_000, -50_000, -50_000, -50_000, -50_000, 50_000, 50_000,
658+
50_000, 50_000, -50_000, -50_000, -50_000, -50_000,
659+
);
660+
661+
// See `test_mm512_packs_epi16` for an explanation of the output structure.
662+
#[rustfmt::skip]
663+
const DST: [i16; 32] = [
664+
8_000, 8_000, 8_000, 8_000,
665+
i16::MIN, i16::MIN, i16::MIN, i16::MIN,
666+
8_000, 8_000, 8_000, 8_000,
667+
i16::MAX, i16::MAX, i16::MAX, i16::MAX,
668+
8_000, 8_000, 8_000, 8_000,
669+
i16::MIN, i16::MIN, i16::MIN, i16::MIN,
670+
8_000, 8_000, 8_000, 8_000,
671+
i16::MAX, i16::MAX, i16::MAX, i16::MAX,
672+
];
673+
let dst = _mm512_loadu_si512(DST.as_ptr().cast::<__m512i>());
674+
assert_eq_m512i(_mm512_packs_epi32(a, b), dst);
675+
}
676+
test_mm512_packs_epi32();
677+
678+
#[target_feature(enable = "avx512bw")]
679+
unsafe fn test_mm512_packus_epi32() {
680+
let a = _mm512_set1_epi32(8_000);
681+
682+
// Because `packus` instructions do unsigned saturation, we expect
683+
// that any value over `u16::MAX` will be saturated to `u16::MAX`, and any value
684+
// less than `u16::MIN` (i.e. any negative input) will be saturated to `u16::MIN`.
685+
let b = _mm512_set_epi32(
686+
80_000, 80_000, 80_000, 80_000, -50_000, -50_000, -50_000, -50_000, 80_000, 80_000,
687+
80_000, 80_000, -50_000, -50_000, -50_000, -50_000,
688+
);
689+
690+
// See `test_mm512_packs_epi16` for an explanation of the output structure.
691+
#[rustfmt::skip]
692+
const DST: [u16; 32] = [
693+
8_000, 8_000, 8_000, 8_000,
694+
u16::MIN, u16::MIN, u16::MIN, u16::MIN,
695+
8_000, 8_000, 8_000, 8_000,
696+
u16::MAX, u16::MAX, u16::MAX, u16::MAX,
697+
8_000, 8_000, 8_000, 8_000,
698+
u16::MIN, u16::MIN, u16::MIN, u16::MIN,
699+
8_000, 8_000, 8_000, 8_000,
700+
u16::MAX, u16::MAX, u16::MAX, u16::MAX,
701+
];
702+
let dst = _mm512_loadu_si512(DST.as_ptr().cast::<__m512i>());
703+
assert_eq_m512i(_mm512_packus_epi32(a, b), dst);
704+
}
705+
test_mm512_packus_epi32();
706+
}
707+
582708
#[track_caller]
583709
unsafe fn assert_eq_m512i(a: __m512i, b: __m512i) {
584-
assert_eq!(transmute::<_, [i32; 16]>(a), transmute::<_, [i32; 16]>(b))
710+
assert_eq!(transmute::<_, [u16; 32]>(a), transmute::<_, [u16; 32]>(b))
585711
}
586712

587713
#[track_caller]

0 commit comments

Comments
 (0)