Add select and conditional operations to SIMD wrappers#10287
Open
copybara-service[bot] wants to merge 1 commit into
Open
Add select and conditional operations to SIMD wrappers#10287copybara-service[bot] wants to merge 1 commit into
select and conditional operations to SIMD wrappers#10287copybara-service[bot] wants to merge 1 commit into
Conversation
To make this work across the variety of SIMD ISA conventions for this, mask types should just be `auto`. They can't be stored to memory. If we need `select` to work on conditions with a different size type than the values, we should be able to handle that with more overloads (and some conversions), but we don't need it yet. This removes the workarounds that avoided the need for this until now: `select_greater_than`, `copynan`, `kahan_sum`. This was partly AI generated, here is what it had to say: ✦ I have completed the implementation of boolean operators and conditional selection across the entire SIMD abstraction layer. Here is a summary of the work: 1. SIMD Core & Generic Layer * Boolean Operators: Implemented ==, !=, <, <=, >, >= for all vector types. These now use auto return types to accommodate varying architecture-specific mask representations (e.g., integer vectors on SSE/AVX vs. __mmaskN on AVX-512). * Conditional Selection: Added a unified select(mask, true_val, false_val) intrinsic. * Predicates: Implemented isnan, isinf, and isfinite for floating-point vectors. * Recursive Generics: Updated generic.inc to support recursive implementation of these operators for wide vectors (e.g., f32x16 on AVX2) by defining lo()/hi()/concat() helpers for all vector and mask types. 2. Architecture-Specific Implementations * x86 AVX-512: Utilized native mask registers (__mmask16, etc.) and _mm512_mask_blend_* instructions. Implemented mask-specific glue logic for recursive templates. * x86 SSE/AVX: Implemented comparisons returning signed integer vectors. Added s64x2 and s64x4 specializations to ensure 64-bit comparisons (for doubles) maintain the correct element count in the resulting mask. * ARM NEON: Leveraged vceqq, vcgtq, and vbslq (bitselect) for mask operations. Added 64-bit integer vector support for ARM64. * WASM SIMD128: Integrated wasm_v128_bitselect and standard WASM comparison intrinsics. 3. Kernel Compiler & Intrinsics * Intrinsic Support: Updated the elementwise kernel compiler (compiler.py) to support select, isnan, isinf, and isfinite. * Type Deduction: Modified the compiler to emit auto for comparison results, allowing the C++ compiler to handle architecture-specific mask types. * Kernel Refactoring: Migrated existing kernels (e.g., exp, sigmoid) away from the legacy copynan helper to the more idiomatic select(isnan(x), ...) pattern. 4. Verification & Cleanup * Comprehensive Testing: Expanded the SIMD test suite in base/simd/test/ to verify comparisons, predicates, and selection logic across all architectures. * Code Quality: Resolved header inclusion ordering issues, fixed redefinition errors between SSE2 and SSE4.1, and performed a significant cleanup of bitwise operator consistency in the AVX headers. Status: All tasks in PLAN.md are complete, and all SIMD test targets (x86_avx512, x86_avx2, x86_avx, x86_sse2, arm_neon, wasm_simd128) are passing. I expected this to be a no-op in terms of performance/generated code, but it improves some benchmarks: ``` bench/sum_squared_k1_bf16_fp32_avx512/real_time [256x256] 5.520µ ± 4% 5.483µ ± 3% ~ (p=0.937 n=6) bench/sum_squared_kn_bf16_fp32_avx512/real_time [256x256] 4.766µ ± 3% 5.053µ ± 11% ~ (p=0.065 n=6) bench/sum_squared_k1_fp16_fp32_avx512/real_time [256x256] 4.495µ ± 3% 4.571µ ± 5% ~ (p=0.310 n=6) bench/sum_squared_kn_fp16_fp32_avx512/real_time [256x256] 4.157µ ± 1% 4.181µ ± 17% ~ (p=0.937 n=6) bench/sum_squared_k1_fp32_avx512/real_time [256x256] 5.506µ ± 15% 5.286µ ± 7% ~ (p=0.310 n=6) bench/sum_squared_kn_fp32_avx512/real_time [256x256] 4.770µ ± 9% 4.691µ ± 17% ~ (p=0.937 n=6) bench/sum_squared_k1_fp64_avx512/real_time [256x256] 9.632µ ± 10% 8.636µ ± 2% -10.34% (p=0.015 n=6) bench/sum_squared_kn_fp64_avx512/real_time [256x256] 10.36µ ± 6% 10.21µ ± 9% ~ (p=0.699 n=6) bench/sum_squared_k1_bf16_fp32_avx2/real_time [256x256] 8.651µ ± 8% 7.314µ ± 4% -15.45% (p=0.002 n=6) bench/sum_squared_kn_bf16_fp32_avx2/real_time [256x256] 7.510µ ± 3% 6.949µ ± 3% -7.47% (p=0.002 n=6) bench/sum_squared_k1_fp32_avx/real_time [256x256] 7.076µ ± 3% 6.491µ ± 3% -8.27% (p=0.002 n=6) bench/sum_squared_kn_fp32_avx/real_time [256x256] 6.368µ ± 4% 5.859µ ± 3% -7.99% (p=0.002 n=6) bench/sum_squared_k1_fp64_avx/real_time [256x256] 12.28µ ± 2% 11.31µ ± 4% -7.89% (p=0.002 n=6) bench/sum_squared_kn_fp64_avx/real_time [256x256] 12.72µ ± 4% 11.61µ ± 9% -8.76% (p=0.009 n=6) bench/sum_squared_k1_fp16_fp32_f16c/real_time [256x256] 7.480µ ± 4% 6.799µ ± 2% -9.11% (p=0.002 n=6) bench/sum_squared_kn_fp16_fp32_f16c/real_time [256x256] 6.445µ ± 2% 6.163µ ± 12% ~ (p=0.065 n=6) bench/sum_squared_k1_fp32_sse2/real_time [256x256] 10.97µ ± 4% 10.31µ ± 5% -6.05% (p=0.009 n=6) bench/sum_squared_kn_fp32_sse2/real_time [256x256] 10.42µ ± 3% 10.21µ ± 3% ~ (p=0.065 n=6) bench/sum_squared_k1_bf16_fp32_sse2/real_time [256x256] 12.85µ ± 5% 12.31µ ± 3% -4.20% (p=0.004 n=6) bench/sum_squared_kn_bf16_fp32_sse2/real_time [256x256] 11.56µ ± 2% 11.25µ ± 4% ~ (p=0.240 n=6) bench/sum_squared_k1_fp64/real_time [256x256] 24.13µ ± 3% 22.98µ ± 19% ~ (p=0.065 n=6) bench/sum_squared_kn_fp64/real_time [256x256] 21.80µ ± 4% 20.33µ ± 16% ~ (p=0.065 n=6) bench/sum_squared_k1_fp16_fp32/real_time [256x256] 183.7µ ± 1% 181.1µ ± 2% ~ (p=0.093 n=6) bench/sum_squared_kn_fp16_fp32/real_time [256x256] 50.73µ ± 7% 47.96µ ± 2% -5.45% (p=0.002 n=6) geomean 9.443µ 9.233µ -2.22% ``` But some kernels slower: ``` name time/op time/op vs base bench_reference/exp_float/m:1/n:4096/real_time 53.03µ ± 1% 53.24µ ± 2% ~ (p=0.485 n=6) bench_reference/exp_double/m:1/n:4096/real_time 64.92µ ± 1% 65.43µ ± 2% ~ (p=0.180 n=6) bench_reference/sigmoid_float/m:1/n:4096/real_time 37.11µ ± 1% 37.41µ ± 2% ~ (p=0.394 n=6) bench_reference/sigmoid_double/m:1/n:4096/real_time 72.69µ ± 20% 72.44µ ± 1% ~ (p=0.699 n=6) bench/exp_fp32_1x32_x86_avx512f_avx512bw/m:1/n:4096/real_time 1.844µ ± 1% 1.872µ ± 3% +1.51% (p=0.015 n=6) bench/exp_fp64_1x16_x86_avx512f_avx512bw/m:1/n:4096/real_time 4.816µ ± 6% 4.743µ ± 2% ~ (p=0.394 n=6) bench/sigmoid_fp32_1x32_x86_avx512f_avx512bw/m:1/n:4096/real_time 1.841µ ± 3% 1.827µ ± 1% ~ (p=0.394 n=6) bench/sigmoid_fp64_1x16_x86_avx512f_avx512bw/m:1/n:4096/real_time 4.549µ ± 7% 4.619µ ± 3% ~ (p=0.093 n=6) bench/exp_fp32_1x32_x86_avx2_fma3/m:1/n:4096/real_time 2.481µ ± 3% 2.575µ ± 3% +3.77% (p=0.002 n=6) bench/exp_fp64_1x16_x86_avx2_fma3/m:1/n:4096/real_time 6.648µ ± 3% 6.834µ ± 3% +2.80% (p=0.026 n=6) bench/exp_fp32_1x16_x86_avx2/m:1/n:4096/real_time 3.777µ ± 2% 3.876µ ± 3% +2.60% (p=0.015 n=6) bench/exp_fp64_1x16_x86_avx2/m:1/n:4096/real_time 9.753µ ± 3% 9.919µ ± 3% ~ (p=0.180 n=6) bench/sigmoid_fp32_1x16_x86_avx2/m:1/n:4096/real_time 4.010µ ± 3% 3.910µ ± 2% ~ (p=0.093 n=6) bench/sigmoid_fp64_1x8_x86_avx2/m:1/n:4096/real_time 10.65µ ± 3% 10.66µ ± 2% ~ (p=0.589 n=6) bench/exp_fp32_1x16_x86_sse2/m:1/n:4096/real_time 5.544µ ± 1% 5.770µ ± 1% +4.07% (p=0.002 n=6) bench/exp_fp64_1x8_x86_sse2/m:1/n:4096/real_time 16.14µ ± 7% 16.53µ ± 3% ~ (p=0.310 n=6) bench/sigmoid_fp32_1x32_x86_sse2/m:1/n:4096/real_time 5.792µ ± 4% 5.917µ ± 4% ~ (p=0.394 n=6) bench/sigmoid_fp64_1x8_x86_sse2/m:1/n:4096/real_time 17.36µ ± 3% 17.72µ ± 2% ~ (p=0.132 n=6) geomean 9.036µ 9.144µ +1.19% ``` PiperOrigin-RevId: 918101088
6a0b23b to
883c665
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Add
selectand conditional operations to SIMD wrappersTo make this work across the variety of SIMD ISA conventions for this, mask types should just be
auto. They can't be stored to memory.If we need
selectto work on conditions with a different size type than the values, we should be able to handle that with more overloads (and some conversions), but we don't need it yet.This removes the workarounds that avoided the need for this until now:
select_greater_than,copynan,kahan_sum.This was partly AI generated, here is what it had to say:
✦ I have completed the implementation of boolean operators and conditional selection across the entire SIMD abstraction layer. Here is a summary of the work:
Status: All tasks in PLAN.md are complete, and all SIMD test targets (x86_avx512, x86_avx2, x86_avx, x86_sse2, arm_neon, wasm_simd128) are passing.
I expected this to be a no-op in terms of performance/generated code, but it improves some benchmarks:
But some kernels slower: