Skip to content

Commit 9f93e49

Browse files
authored
winch: respect the enable_nan_canonicalization setting (#12939)
* winch: respect the enable_nan_canonicalization setting
* add disas tests for NaN canonicalization
* rename canonicalize_nan to maybe_canonicalize_nan
* extract canonical NaN constants to shared masm module
* remove canonicalize_nan_for_round, inline at call sites
* implement SIMD NaN canonicalization for x64
* cargo fmt
* add comment about branchless scalar canonicalization opportunity
* add canonicalize-nan.wast to no-AVX skip list
1 parent 0096013 commit 9f93e49

9 files changed

Lines changed: 525 additions & 23 deletions

File tree

crates/test-util/src/wast.rs

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -620,16 +620,6 @@ impl WastTest {
620620

621621
#[cfg(target_arch = "x86_64")]
622622
{
623-
let unsupported = [
624-
// externref/reference-types related
625-
// simd-related failures
626-
"misc_testsuite/simd/canonicalize-nan.wast",
627-
];
628-
629-
if unsupported.iter().any(|part| self.path.ends_with(part)) {
630-
return true;
631-
}
632-
633623
// SIMD on Winch requires AVX instructions.
634624
#[cfg(target_arch = "x86_64")]
635625
if !(std::is_x86_feature_detected!("avx") && std::is_x86_feature_detected!("avx2"))
@@ -640,6 +630,7 @@ impl WastTest {
640630
"misc_testsuite/int-to-float-splat.wast",
641631
"misc_testsuite/issue6562.wast",
642632
"misc_testsuite/simd/almost-extmul.wast",
633+
"misc_testsuite/simd/canonicalize-nan.wast",
643634
"misc_testsuite/simd/cvt-from-uint.wast",
644635
"misc_testsuite/simd/edge-of-memory.wast",
645636
"misc_testsuite/simd/issue_3327_bnot_lowering.wast",
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
;;! target = "x86_64"
2+
;;! test = "winch"
3+
;;! flags = "-Wnan-canonicalization"
4+
5+
(module
6+
(func (param f32 f32) (result f32)
7+
local.get 0
8+
local.get 1
9+
f32.add
10+
)
11+
)
12+
;; wasm[0]::function[0]:
13+
;; pushq %rbp
14+
;; movq %rsp, %rbp
15+
;; movq 8(%rdi), %r11
16+
;; movq 0x18(%r11), %r11
17+
;; addq $0x20, %r11
18+
;; cmpq %rsp, %r11
19+
;; ja 0x69
20+
;; 1c: movq %rdi, %r14
21+
;; subq $0x20, %rsp
22+
;; movq %rdi, 0x18(%rsp)
23+
;; movq %rsi, 0x10(%rsp)
24+
;; movss %xmm0, 0xc(%rsp)
25+
;; movss %xmm1, 8(%rsp)
26+
;; movss 8(%rsp), %xmm0
27+
;; movss 0xc(%rsp), %xmm1
28+
;; addss %xmm0, %xmm1
29+
;; ucomiss %xmm1, %xmm1
30+
;; jnp 0x5d
31+
;; 55: movss 0x13(%rip), %xmm1
32+
;; movaps %xmm1, %xmm0
33+
;; addq $0x20, %rsp
34+
;; popq %rbp
35+
;; retq
36+
;; 69: ud2
37+
;; 6b: addb %al, (%rax)
38+
;; 6d: addb %al, (%rax)
39+
;; 6f: addb %al, (%rax)
40+
;; 71: addb %al, %al
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
;;! target = "x86_64"
2+
;;! test = "winch"
3+
;;! flags = ["-Wnan-canonicalization", "-Ccranelift-has-avx"]
4+
5+
(module
6+
(func (param v128 v128) (result v128)
7+
local.get 0
8+
local.get 1
9+
f32x4.add
10+
)
11+
)
12+
;; wasm[0]::function[0]:
13+
;; pushq %rbp
14+
;; movq %rsp, %rbp
15+
;; movq 8(%rdi), %r11
16+
;; movq 0x18(%r11), %r11
17+
;; addq $0x30, %r11
18+
;; cmpq %rsp, %r11
19+
;; ja 0x6c
20+
;; 1c: movq %rdi, %r14
21+
;; subq $0x30, %rsp
22+
;; movq %rdi, 0x28(%rsp)
23+
;; movq %rsi, 0x20(%rsp)
24+
;; movdqu %xmm0, 0x10(%rsp)
25+
;; movdqu %xmm1, (%rsp)
26+
;; movdqu (%rsp), %xmm0
27+
;; movdqu 0x10(%rsp), %xmm1
28+
;; vaddps %xmm0, %xmm1, %xmm1
29+
;; vcmpunordps %xmm1, %xmm1, %xmm15
30+
;; vandnps %xmm1, %xmm15, %xmm1
31+
;; vandps 0x15(%rip), %xmm15, %xmm15
32+
;; vorps %xmm1, %xmm15, %xmm1
33+
;; movdqa %xmm1, %xmm0
34+
;; addq $0x30, %rsp
35+
;; popq %rbp
36+
;; retq
37+
;; 6c: ud2
38+
;; 6e: addb %al, (%rax)
39+
;; 70: addb %al, (%rax)
40+
;; 72: sarb $0, (%rdi)
41+
;; 76: sarb $0, (%rdi)
42+
;; 7a: sarb $0, (%rdi)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
;;! target = "x86_64"
2+
;;! test = "winch"
3+
;;! flags = "-Wnan-canonicalization"
4+
5+
(module
6+
(func (param f64 f64) (result f64)
7+
local.get 0
8+
local.get 1
9+
f64.div
10+
)
11+
)
12+
;; wasm[0]::function[0]:
13+
;; pushq %rbp
14+
;; movq %rsp, %rbp
15+
;; movq 8(%rdi), %r11
16+
;; movq 0x18(%r11), %r11
17+
;; addq $0x20, %r11
18+
;; cmpq %rsp, %r11
19+
;; ja 0x68
20+
;; 1c: movq %rdi, %r14
21+
;; subq $0x20, %rsp
22+
;; movq %rdi, 0x18(%rsp)
23+
;; movq %rsi, 0x10(%rsp)
24+
;; movsd %xmm0, 8(%rsp)
25+
;; movsd %xmm1, (%rsp)
26+
;; movsd (%rsp), %xmm0
27+
;; movsd 8(%rsp), %xmm1
28+
;; divsd %xmm0, %xmm1
29+
;; ucomisd %xmm1, %xmm1
30+
;; jnp 0x5c
31+
;; 54: movsd 0x14(%rip), %xmm1
32+
;; movaps %xmm1, %xmm0
33+
;; addq $0x20, %rsp
34+
;; popq %rbp
35+
;; retq
36+
;; 68: ud2
37+
;; 6a: addb %al, (%rax)
38+
;; 6c: addb %al, (%rax)
39+
;; 6e: addb %al, (%rax)
40+
;; 70: addb %al, (%rax)
41+
;; 72: addb %al, (%rax)
42+
;; 74: addb %al, (%rax)
43+
;; 76: clc
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
;;! nan_canonicalization = true
2+
3+
;; Scalar counterpart to simd/canonicalize-nan.wast.
4+
5+
(module
6+
(func (export "f32.add") (param f32 f32) (result f32)
7+
local.get 0
8+
local.get 1
9+
f32.add)
10+
(func (export "f32.sub") (param f32 f32) (result f32)
11+
local.get 0
12+
local.get 1
13+
f32.sub)
14+
(func (export "f32.mul") (param f32 f32) (result f32)
15+
local.get 0
16+
local.get 1
17+
f32.mul)
18+
(func (export "f32.div") (param f32 f32) (result f32)
19+
local.get 0
20+
local.get 1
21+
f32.div)
22+
(func (export "f32.min") (param f32 f32) (result f32)
23+
local.get 0
24+
local.get 1
25+
f32.min)
26+
(func (export "f32.max") (param f32 f32) (result f32)
27+
local.get 0
28+
local.get 1
29+
f32.max)
30+
(func (export "f32.sqrt") (param f32) (result f32)
31+
local.get 0
32+
f32.sqrt)
33+
(func (export "f32.ceil") (param f32) (result f32)
34+
local.get 0
35+
f32.ceil)
36+
(func (export "f32.floor") (param f32) (result f32)
37+
local.get 0
38+
f32.floor)
39+
(func (export "f32.trunc") (param f32) (result f32)
40+
local.get 0
41+
f32.trunc)
42+
(func (export "f32.nearest") (param f32) (result f32)
43+
local.get 0
44+
f32.nearest)
45+
46+
(func (export "f64.add") (param f64 f64) (result f64)
47+
local.get 0
48+
local.get 1
49+
f64.add)
50+
(func (export "f64.sub") (param f64 f64) (result f64)
51+
local.get 0
52+
local.get 1
53+
f64.sub)
54+
(func (export "f64.mul") (param f64 f64) (result f64)
55+
local.get 0
56+
local.get 1
57+
f64.mul)
58+
(func (export "f64.div") (param f64 f64) (result f64)
59+
local.get 0
60+
local.get 1
61+
f64.div)
62+
(func (export "f64.min") (param f64 f64) (result f64)
63+
local.get 0
64+
local.get 1
65+
f64.min)
66+
(func (export "f64.max") (param f64 f64) (result f64)
67+
local.get 0
68+
local.get 1
69+
f64.max)
70+
(func (export "f64.sqrt") (param f64) (result f64)
71+
local.get 0
72+
f64.sqrt)
73+
(func (export "f64.ceil") (param f64) (result f64)
74+
local.get 0
75+
f64.ceil)
76+
(func (export "f64.floor") (param f64) (result f64)
77+
local.get 0
78+
f64.floor)
79+
(func (export "f64.trunc") (param f64) (result f64)
80+
local.get 0
81+
f64.trunc)
82+
(func (export "f64.nearest") (param f64) (result f64)
83+
local.get 0
84+
f64.nearest)
85+
86+
(func (export "reinterpret-and-demote") (param i64) (result i32)
87+
local.get 0
88+
f64.reinterpret_i64
89+
f32.demote_f64
90+
i32.reinterpret_f32)
91+
(func (export "reinterpret-and-promote") (param i32) (result i64)
92+
local.get 0
93+
f32.reinterpret_i32
94+
f64.promote_f32
95+
i64.reinterpret_f64)
96+
97+
;; Expose raw bits of 0/0 to verify exact canonical NaN bit patterns.
98+
(func (export "f32.div-nan-bits") (result i32)
99+
f32.const 0
100+
f32.const 0
101+
f32.div
102+
i32.reinterpret_f32)
103+
(func (export "f64.div-nan-bits") (result i64)
104+
f64.const 0
105+
f64.const 0
106+
f64.div
107+
i64.reinterpret_f64)
108+
)
109+
110+
;; Exact bit patterns: canonical f32 NaN = 0x7fc00000, f64 = 0x7ff8000000000000
111+
(assert_return (invoke "f32.div-nan-bits") (i32.const 0x7fc00000))
112+
(assert_return (invoke "f64.div-nan-bits") (i64.const 0x7ff8000000000000))
113+
114+
;; NaN-producing operations
115+
(assert_return (invoke "f32.div" (f32.const 0) (f32.const 0)) (f32.const nan:0x400000))
116+
(assert_return (invoke "f64.div" (f64.const 0) (f64.const 0)) (f64.const nan:0x8000000000000))
117+
(assert_return (invoke "f32.sqrt" (f32.const -1)) (f32.const nan:0x400000))
118+
(assert_return (invoke "f64.sqrt" (f64.const -1)) (f64.const nan:0x8000000000000))
119+
120+
;; NaN propagation through f32 arithmetic
121+
(assert_return (invoke "f32.add" (f32.const nan) (f32.const 1)) (f32.const nan:0x400000))
122+
(assert_return (invoke "f32.sub" (f32.const nan) (f32.const 1)) (f32.const nan:0x400000))
123+
(assert_return (invoke "f32.mul" (f32.const nan) (f32.const 1)) (f32.const nan:0x400000))
124+
(assert_return (invoke "f32.min" (f32.const nan) (f32.const 1)) (f32.const nan:0x400000))
125+
(assert_return (invoke "f32.max" (f32.const nan) (f32.const 1)) (f32.const nan:0x400000))
126+
127+
;; NaN propagation through f64 arithmetic
128+
(assert_return (invoke "f64.add" (f64.const nan) (f64.const 1)) (f64.const nan:0x8000000000000))
129+
(assert_return (invoke "f64.sub" (f64.const nan) (f64.const 1)) (f64.const nan:0x8000000000000))
130+
(assert_return (invoke "f64.mul" (f64.const nan) (f64.const 1)) (f64.const nan:0x8000000000000))
131+
(assert_return (invoke "f64.min" (f64.const nan) (f64.const 1)) (f64.const nan:0x8000000000000))
132+
(assert_return (invoke "f64.max" (f64.const nan) (f64.const 1)) (f64.const nan:0x8000000000000))
133+
134+
;; Rounding NaN (f32)
135+
(assert_return (invoke "f32.ceil" (f32.const nan)) (f32.const nan:0x400000))
136+
(assert_return (invoke "f32.floor" (f32.const nan)) (f32.const nan:0x400000))
137+
(assert_return (invoke "f32.trunc" (f32.const nan)) (f32.const nan:0x400000))
138+
(assert_return (invoke "f32.nearest" (f32.const nan)) (f32.const nan:0x400000))
139+
140+
;; Rounding NaN (f64)
141+
(assert_return (invoke "f64.ceil" (f64.const nan)) (f64.const nan:0x8000000000000))
142+
(assert_return (invoke "f64.floor" (f64.const nan)) (f64.const nan:0x8000000000000))
143+
(assert_return (invoke "f64.trunc" (f64.const nan)) (f64.const nan:0x8000000000000))
144+
(assert_return (invoke "f64.nearest" (f64.const nan)) (f64.const nan:0x8000000000000))
145+
146+
;; Demote/promote with non-canonical NaN bit patterns
147+
(assert_return (invoke "reinterpret-and-demote" (i64.const 0xfffefdfccccdcecf)) (i32.const 0x7fc00000))
148+
(assert_return (invoke "reinterpret-and-promote" (i32.const 0xfffefdfc)) (i64.const 0x7ff8000000000000))
149+
150+
;; Normal values pass through unchanged
151+
(assert_return (invoke "f32.add" (f32.const 1) (f32.const 2)) (f32.const 3))
152+
(assert_return (invoke "f64.div" (f64.const 10) (f64.const 2)) (f64.const 5))
153+
(assert_return (invoke "f32.sqrt" (f32.const 4)) (f32.const 2))

winch/codegen/src/isa/aarch64/masm.rs

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ pub(crate) struct MacroAssembler {
5959
ptr_size: OperandSize,
6060
/// Scratch register scope.
6161
scratch_scope: RegAlloc,
62+
/// Shared flags.
63+
shared_flags: settings::Flags,
6264
}
6365

6466
impl MacroAssembler {
@@ -71,10 +73,11 @@ impl MacroAssembler {
7173
Ok(Self {
7274
sp_max: 0,
7375
stack_max_use_add: None,
74-
asm: Assembler::new(shared_flags, isa_flags),
76+
asm: Assembler::new(shared_flags.clone(), isa_flags),
7577
sp_offset: 0u32,
7678
ptr_size: ptr_type_from_ptr_size(ptr_size.size()).try_into()?,
7779
scratch_scope: RegAlloc::from(scratch_gpr_bitset(), scratch_fpr_bitset()),
80+
shared_flags,
7881
})
7982
}
8083

@@ -713,6 +716,43 @@ impl Masm for MacroAssembler {
713716
Ok(())
714717
}
715718

719+
/// Emits code that replaces a NaN held in `reg` with the canonical NaN for
/// `size`, but only when the `enable_nan_canonicalization` shared flag is set.
///
/// With the flag disabled nothing is emitted and `Ok(())` is returned.
/// Otherwise the emitted sequence is: a floating-point compare of `reg`
/// against itself, a conditional jump to a local "done" label, and a load of
/// the canonical NaN constant into `reg`, after which the label is bound.
///
/// # Errors
///
/// Bails with `CodeGenError::unexpected_operand_size()` when `size` is
/// neither `OperandSize::S32` nor `OperandSize::S64`.
fn maybe_canonicalize_nan(&mut self, reg: WritableReg, size: OperandSize) -> Result<()> {
720+
// Canonicalization is opt-in: emit nothing unless the shared flag asks for it.
if !self.shared_flags.enable_nan_canonicalization() {
721+
return Ok(());
722+
}
723+
724+
// Label bound at the end of the sequence; the jump below targets it.
let done_label = self.asm.buffer_mut().get_label();
725+
726+
// Compare the register with itself and skip the constant load when
// `Cond::Vc` holds.
// NOTE(review): `Vc` is presumably the "ordered" outcome of the
// self-compare, i.e. the value is not a NaN — confirm against the AArch64
// condition-code definitions.
self.asm.fcmp(reg.to_reg(), reg.to_reg(), size);
727+
self.asm.jmp_if(Cond::Vc, done_label);
728+
729+
// Select the canonical NaN constant matching the operand width; any
// other size is rejected.
let canonical_nan = match size {
730+
OperandSize::S32 => crate::masm::CANONICAL_NAN_F32,
731+
OperandSize::S64 => crate::masm::CANONICAL_NAN_F64,
732+
_ => bail!(CodeGenError::unexpected_operand_size()),
733+
};
734+
// Register the constant with the assembler and load it into `reg`,
// overwriting the previous value.
let constant = self.asm.add_constant(canonical_nan);
735+
self.asm.uload(
736+
inst::AMode::Const { addr: constant },
737+
reg,
738+
size,
739+
TRUSTED_FLAGS,
740+
);
741+
742+
// Bind the "done" label here so the jump above lands after the load.
self.asm
743+
.buffer_mut()
744+
.bind_label(done_label, &mut Default::default());
745+
Ok(())
746+
}
747+
748+
/// Vector (v128) counterpart of `maybe_canonicalize_nan`.
///
/// This implementation emits no code: both parameters are unused and the
/// function unconditionally bails.
///
/// # Errors
///
/// Always bails with `CodeGenError::unimplemented_masm_instruction()`.
fn maybe_canonicalize_v128_nan(
749+
&mut self,
750+
_reg: WritableReg,
751+
_lane_size: OperandSize,
752+
) -> Result<()> {
753+
// Unconditional: the shared flags are not consulted before bailing.
bail!(CodeGenError::unimplemented_masm_instruction())
754+
}
755+
716756
fn and(&mut self, dst: WritableReg, lhs: Reg, rhs: RegImm, size: OperandSize) -> Result<()> {
717757
match (rhs, lhs, dst) {
718758
(RegImm::Imm(v), rn, rd) => {

0 commit comments

Comments
 (0)