Skip to content

Add select and conditional operations to SIMD wrappers#10287

Open
copybara-service[bot] wants to merge 1 commit into
masterfrom
test_918101088
Open

Add select and conditional operations to SIMD wrappers#10287
copybara-service[bot] wants to merge 1 commit into
masterfrom
test_918101088

Conversation

@copybara-service
Copy link
Copy Markdown
Contributor

Add select and conditional operations to SIMD wrappers

To make this work across the variety of SIMD ISA conventions for this, mask types should just be auto. They can't be stored to memory.

If we need select to work on conditions with a different size type than the values, we should be able to handle that with more overloads (and some conversions), but we don't need it yet.

This removes the workarounds that avoided the need for this until now: select_greater_than, copynan, kahan_sum.

This was partly AI generated, here is what it had to say:

✦ I have completed the implementation of boolean operators and conditional selection across the entire SIMD abstraction layer. Here is a summary of the work:

  1. SIMD Core & Generic Layer
  • Boolean Operators: Implemented ==, !=, <, <=, >, >= for all vector types. These now use auto return types to accommodate varying architecture-specific mask representations (e.g., integer vectors on SSE/AVX vs. __mmaskN on AVX-512).
  • Conditional Selection: Added a unified select(mask, true_val, false_val) intrinsic.
  • Predicates: Implemented isnan, isinf, and isfinite for floating-point vectors.
  • Recursive Generics: Updated generic.inc to support recursive implementation of these operators for wide vectors (e.g., f32x16 on AVX2) by defining lo()/hi()/concat() helpers for all vector and mask types.
  1. Architecture-Specific Implementations
  • x86 AVX-512: Utilized native mask registers (__mmask16, etc.) and mm512_mask_blend* instructions. Implemented mask-specific glue logic for recursive templates.
  • x86 SSE/AVX: Implemented comparisons returning signed integer vectors. Added s64x2 and s64x4 specializations to ensure 64-bit comparisons (for doubles) maintain the correct element count in the resulting mask.
  • ARM NEON: Leveraged vceqq, vcgtq, and vbslq (bitselect) for mask operations. Added 64-bit integer vector support for ARM64.
  • WASM SIMD128: Integrated wasm_v128_bitselect and standard WASM comparison intrinsics.
  1. Kernel Compiler & Intrinsics
  • Intrinsic Support: Updated the elementwise kernel compiler (compiler.py) to support select, isnan, isinf, and isfinite.
  • Type Deduction: Modified the compiler to emit auto for comparison results, allowing the C++ compiler to handle architecture-specific mask types.
  • Kernel Refactoring: Migrated existing kernels (e.g., exp, sigmoid) away from the legacy copynan helper to the more idiomatic select(isnan(x), ...) pattern.
  1. Verification & Cleanup
  • Comprehensive Testing: Expanded the SIMD test suite in base/simd/test/ to verify comparisons, predicates, and selection logic across all architectures.
  • Code Quality: Resolved header inclusion ordering issues, fixed redefinition errors between SSE2 and SSE4.1, and performed a significant cleanup of bitwise operator consistency in the AVX headers.

Status: All tasks in PLAN.md are complete, and all SIMD test targets (x86_avx512, x86_avx2, x86_avx, x86_sse2, arm_neon, wasm_simd128) are passing.

I expected this to be a no-op in terms of performance/generated code, but it improves some benchmarks:

bench/sum_squared_k1_bf16_fp32_avx512/real_time     [256x256] 5.520µ ±  4%    5.483µ ±  3%        ~ (p=0.937 n=6)
bench/sum_squared_kn_bf16_fp32_avx512/real_time     [256x256] 4.766µ ±  3%    5.053µ ± 11%        ~ (p=0.065 n=6)
bench/sum_squared_k1_fp16_fp32_avx512/real_time     [256x256] 4.495µ ±  3%    4.571µ ±  5%        ~ (p=0.310 n=6)
bench/sum_squared_kn_fp16_fp32_avx512/real_time     [256x256] 4.157µ ±  1%    4.181µ ± 17%        ~ (p=0.937 n=6)
bench/sum_squared_k1_fp32_avx512/real_time          [256x256] 5.506µ ± 15%    5.286µ ±  7%        ~ (p=0.310 n=6)
bench/sum_squared_kn_fp32_avx512/real_time          [256x256] 4.770µ ±  9%    4.691µ ± 17%        ~ (p=0.937 n=6)
bench/sum_squared_k1_fp64_avx512/real_time          [256x256] 9.632µ ± 10%    8.636µ ±  2%  -10.34% (p=0.015 n=6)
bench/sum_squared_kn_fp64_avx512/real_time          [256x256] 10.36µ ±  6%    10.21µ ±  9%        ~ (p=0.699 n=6)
bench/sum_squared_k1_bf16_fp32_avx2/real_time       [256x256] 8.651µ ±  8%    7.314µ ±  4%  -15.45% (p=0.002 n=6)
bench/sum_squared_kn_bf16_fp32_avx2/real_time       [256x256] 7.510µ ±  3%    6.949µ ±  3%   -7.47% (p=0.002 n=6)
bench/sum_squared_k1_fp32_avx/real_time             [256x256] 7.076µ ±  3%    6.491µ ±  3%   -8.27% (p=0.002 n=6)
bench/sum_squared_kn_fp32_avx/real_time             [256x256] 6.368µ ±  4%    5.859µ ±  3%   -7.99% (p=0.002 n=6)
bench/sum_squared_k1_fp64_avx/real_time             [256x256] 12.28µ ±  2%    11.31µ ±  4%   -7.89% (p=0.002 n=6)
bench/sum_squared_kn_fp64_avx/real_time             [256x256] 12.72µ ±  4%    11.61µ ±  9%   -8.76% (p=0.009 n=6)
bench/sum_squared_k1_fp16_fp32_f16c/real_time       [256x256] 7.480µ ±  4%    6.799µ ±  2%   -9.11% (p=0.002 n=6)
bench/sum_squared_kn_fp16_fp32_f16c/real_time       [256x256] 6.445µ ±  2%    6.163µ ± 12%        ~ (p=0.065 n=6)
bench/sum_squared_k1_fp32_sse2/real_time            [256x256] 10.97µ ±  4%    10.31µ ±  5%   -6.05% (p=0.009 n=6)
bench/sum_squared_kn_fp32_sse2/real_time            [256x256] 10.42µ ±  3%    10.21µ ±  3%        ~ (p=0.065 n=6)
bench/sum_squared_k1_bf16_fp32_sse2/real_time       [256x256] 12.85µ ±  5%    12.31µ ±  3%   -4.20% (p=0.004 n=6)
bench/sum_squared_kn_bf16_fp32_sse2/real_time       [256x256] 11.56µ ±  2%    11.25µ ±  4%        ~ (p=0.240 n=6)
bench/sum_squared_k1_fp64/real_time                 [256x256] 24.13µ ±  3%    22.98µ ± 19%        ~ (p=0.065 n=6)
bench/sum_squared_kn_fp64/real_time                 [256x256] 21.80µ ±  4%    20.33µ ± 16%        ~ (p=0.065 n=6)
bench/sum_squared_k1_fp16_fp32/real_time            [256x256] 183.7µ ±  1%    181.1µ ±  2%        ~ (p=0.093 n=6)
bench/sum_squared_kn_fp16_fp32/real_time            [256x256] 50.73µ ±  7%    47.96µ ±  2%   -5.45% (p=0.002 n=6)
geomean                                                       9.443µ          9.233µ         -2.22%

But some kernels slower:

name                                                                time/op       time/op     vs base              
bench_reference/exp_float/m:1/n:4096/real_time                      53.03µ ±  1%   53.24µ ± 2%       ~ (p=0.485 n=6)
bench_reference/exp_double/m:1/n:4096/real_time                     64.92µ ±  1%   65.43µ ± 2%       ~ (p=0.180 n=6)
bench_reference/sigmoid_float/m:1/n:4096/real_time                  37.11µ ±  1%   37.41µ ± 2%       ~ (p=0.394 n=6)
bench_reference/sigmoid_double/m:1/n:4096/real_time                 72.69µ ± 20%   72.44µ ± 1%       ~ (p=0.699 n=6)
bench/exp_fp32_1x32_x86_avx512f_avx512bw/m:1/n:4096/real_time       1.844µ ±  1%   1.872µ ± 3%  +1.51% (p=0.015 n=6)
bench/exp_fp64_1x16_x86_avx512f_avx512bw/m:1/n:4096/real_time       4.816µ ±  6%   4.743µ ± 2%       ~ (p=0.394 n=6)
bench/sigmoid_fp32_1x32_x86_avx512f_avx512bw/m:1/n:4096/real_time   1.841µ ±  3%   1.827µ ± 1%       ~ (p=0.394 n=6)
bench/sigmoid_fp64_1x16_x86_avx512f_avx512bw/m:1/n:4096/real_time   4.549µ ±  7%   4.619µ ± 3%       ~ (p=0.093 n=6)
bench/exp_fp32_1x32_x86_avx2_fma3/m:1/n:4096/real_time              2.481µ ±  3%   2.575µ ± 3%  +3.77% (p=0.002 n=6)
bench/exp_fp64_1x16_x86_avx2_fma3/m:1/n:4096/real_time              6.648µ ±  3%   6.834µ ± 3%  +2.80% (p=0.026 n=6)
bench/exp_fp32_1x16_x86_avx2/m:1/n:4096/real_time                   3.777µ ±  2%   3.876µ ± 3%  +2.60% (p=0.015 n=6)
bench/exp_fp64_1x16_x86_avx2/m:1/n:4096/real_time                   9.753µ ±  3%   9.919µ ± 3%       ~ (p=0.180 n=6)
bench/sigmoid_fp32_1x16_x86_avx2/m:1/n:4096/real_time               4.010µ ±  3%   3.910µ ± 2%       ~ (p=0.093 n=6)
bench/sigmoid_fp64_1x8_x86_avx2/m:1/n:4096/real_time                10.65µ ±  3%   10.66µ ± 2%       ~ (p=0.589 n=6)
bench/exp_fp32_1x16_x86_sse2/m:1/n:4096/real_time                   5.544µ ±  1%   5.770µ ± 1%  +4.07% (p=0.002 n=6)
bench/exp_fp64_1x8_x86_sse2/m:1/n:4096/real_time                    16.14µ ±  7%   16.53µ ± 3%       ~ (p=0.310 n=6)
bench/sigmoid_fp32_1x32_x86_sse2/m:1/n:4096/real_time               5.792µ ±  4%   5.917µ ± 4%       ~ (p=0.394 n=6)
bench/sigmoid_fp64_1x8_x86_sse2/m:1/n:4096/real_time                17.36µ ±  3%   17.72µ ± 2%       ~ (p=0.132 n=6)
geomean                                                             9.036µ         9.144µ       +1.19%

To make this work across the variety of SIMD ISA conventions for this, mask types should just be `auto`. They can't be stored to memory.

If we need `select` to work on conditions with a different size type than the values, we should be able to handle that with more overloads (and some conversions), but we don't need it yet.

This removes the workarounds that avoided the need for this until now: `select_greater_than`, `copynan`, `kahan_sum`.

This was partly AI generated, here is what it had to say:

✦ I have completed the implementation of boolean operators and conditional selection across the entire SIMD abstraction layer. Here is a summary of the work:

  1. SIMD Core & Generic Layer
   * Boolean Operators: Implemented ==, !=, <, <=, >, >= for all vector types. These now use auto return types to accommodate varying architecture-specific mask representations (e.g., integer vectors on SSE/AVX vs. __mmaskN on AVX-512).
   * Conditional Selection: Added a unified select(mask, true_val, false_val) intrinsic.
   * Predicates: Implemented isnan, isinf, and isfinite for floating-point vectors.
   * Recursive Generics: Updated generic.inc to support recursive implementation of these operators for wide vectors (e.g., f32x16 on AVX2) by defining lo()/hi()/concat() helpers for all vector and mask types.

  2. Architecture-Specific Implementations
   * x86 AVX-512: Utilized native mask registers (__mmask16, etc.) and _mm512_mask_blend_* instructions. Implemented mask-specific glue logic for recursive templates.
   * x86 SSE/AVX: Implemented comparisons returning signed integer vectors. Added s64x2 and s64x4 specializations to ensure 64-bit comparisons (for doubles) maintain the correct element count in the resulting mask.
   * ARM NEON: Leveraged vceqq, vcgtq, and vbslq (bitselect) for mask operations. Added 64-bit integer vector support for ARM64.
   * WASM SIMD128: Integrated wasm_v128_bitselect and standard WASM comparison intrinsics.

  3. Kernel Compiler & Intrinsics
   * Intrinsic Support: Updated the elementwise kernel compiler (compiler.py) to support select, isnan, isinf, and isfinite.
   * Type Deduction: Modified the compiler to emit auto for comparison results, allowing the C++ compiler to handle architecture-specific mask types.
   * Kernel Refactoring: Migrated existing kernels (e.g., exp, sigmoid) away from the legacy copynan helper to the more idiomatic select(isnan(x), ...) pattern.

  4. Verification & Cleanup
   * Comprehensive Testing: Expanded the SIMD test suite in base/simd/test/ to verify comparisons, predicates, and selection logic across all architectures.
   * Code Quality: Resolved header inclusion ordering issues, fixed redefinition errors between SSE2 and SSE4.1, and performed a significant cleanup of bitwise operator consistency in the AVX headers.

  Status: All tasks in PLAN.md are complete, and all SIMD test targets (x86_avx512, x86_avx2, x86_avx, x86_sse2, arm_neon, wasm_simd128) are passing.

I expected this to be a no-op in terms of performance/generated code, but it improves some benchmarks:
```
bench/sum_squared_k1_bf16_fp32_avx512/real_time     [256x256] 5.520µ ±  4%    5.483µ ±  3%        ~ (p=0.937 n=6)
bench/sum_squared_kn_bf16_fp32_avx512/real_time     [256x256] 4.766µ ±  3%    5.053µ ± 11%        ~ (p=0.065 n=6)
bench/sum_squared_k1_fp16_fp32_avx512/real_time     [256x256] 4.495µ ±  3%    4.571µ ±  5%        ~ (p=0.310 n=6)
bench/sum_squared_kn_fp16_fp32_avx512/real_time     [256x256] 4.157µ ±  1%    4.181µ ± 17%        ~ (p=0.937 n=6)
bench/sum_squared_k1_fp32_avx512/real_time          [256x256] 5.506µ ± 15%    5.286µ ±  7%        ~ (p=0.310 n=6)
bench/sum_squared_kn_fp32_avx512/real_time          [256x256] 4.770µ ±  9%    4.691µ ± 17%        ~ (p=0.937 n=6)
bench/sum_squared_k1_fp64_avx512/real_time          [256x256] 9.632µ ± 10%    8.636µ ±  2%  -10.34% (p=0.015 n=6)
bench/sum_squared_kn_fp64_avx512/real_time          [256x256] 10.36µ ±  6%    10.21µ ±  9%        ~ (p=0.699 n=6)
bench/sum_squared_k1_bf16_fp32_avx2/real_time       [256x256] 8.651µ ±  8%    7.314µ ±  4%  -15.45% (p=0.002 n=6)
bench/sum_squared_kn_bf16_fp32_avx2/real_time       [256x256] 7.510µ ±  3%    6.949µ ±  3%   -7.47% (p=0.002 n=6)
bench/sum_squared_k1_fp32_avx/real_time             [256x256] 7.076µ ±  3%    6.491µ ±  3%   -8.27% (p=0.002 n=6)
bench/sum_squared_kn_fp32_avx/real_time             [256x256] 6.368µ ±  4%    5.859µ ±  3%   -7.99% (p=0.002 n=6)
bench/sum_squared_k1_fp64_avx/real_time             [256x256] 12.28µ ±  2%    11.31µ ±  4%   -7.89% (p=0.002 n=6)
bench/sum_squared_kn_fp64_avx/real_time             [256x256] 12.72µ ±  4%    11.61µ ±  9%   -8.76% (p=0.009 n=6)
bench/sum_squared_k1_fp16_fp32_f16c/real_time       [256x256] 7.480µ ±  4%    6.799µ ±  2%   -9.11% (p=0.002 n=6)
bench/sum_squared_kn_fp16_fp32_f16c/real_time       [256x256] 6.445µ ±  2%    6.163µ ± 12%        ~ (p=0.065 n=6)
bench/sum_squared_k1_fp32_sse2/real_time            [256x256] 10.97µ ±  4%    10.31µ ±  5%   -6.05% (p=0.009 n=6)
bench/sum_squared_kn_fp32_sse2/real_time            [256x256] 10.42µ ±  3%    10.21µ ±  3%        ~ (p=0.065 n=6)
bench/sum_squared_k1_bf16_fp32_sse2/real_time       [256x256] 12.85µ ±  5%    12.31µ ±  3%   -4.20% (p=0.004 n=6)
bench/sum_squared_kn_bf16_fp32_sse2/real_time       [256x256] 11.56µ ±  2%    11.25µ ±  4%        ~ (p=0.240 n=6)
bench/sum_squared_k1_fp64/real_time                 [256x256] 24.13µ ±  3%    22.98µ ± 19%        ~ (p=0.065 n=6)
bench/sum_squared_kn_fp64/real_time                 [256x256] 21.80µ ±  4%    20.33µ ± 16%        ~ (p=0.065 n=6)
bench/sum_squared_k1_fp16_fp32/real_time            [256x256] 183.7µ ±  1%    181.1µ ±  2%        ~ (p=0.093 n=6)
bench/sum_squared_kn_fp16_fp32/real_time            [256x256] 50.73µ ±  7%    47.96µ ±  2%   -5.45% (p=0.002 n=6)
geomean                                                       9.443µ          9.233µ         -2.22%
```
But some kernels slower:
```
name                                                                time/op       time/op     vs base
bench_reference/exp_float/m:1/n:4096/real_time                      53.03µ ±  1%   53.24µ ± 2%       ~ (p=0.485 n=6)
bench_reference/exp_double/m:1/n:4096/real_time                     64.92µ ±  1%   65.43µ ± 2%       ~ (p=0.180 n=6)
bench_reference/sigmoid_float/m:1/n:4096/real_time                  37.11µ ±  1%   37.41µ ± 2%       ~ (p=0.394 n=6)
bench_reference/sigmoid_double/m:1/n:4096/real_time                 72.69µ ± 20%   72.44µ ± 1%       ~ (p=0.699 n=6)
bench/exp_fp32_1x32_x86_avx512f_avx512bw/m:1/n:4096/real_time       1.844µ ±  1%   1.872µ ± 3%  +1.51% (p=0.015 n=6)
bench/exp_fp64_1x16_x86_avx512f_avx512bw/m:1/n:4096/real_time       4.816µ ±  6%   4.743µ ± 2%       ~ (p=0.394 n=6)
bench/sigmoid_fp32_1x32_x86_avx512f_avx512bw/m:1/n:4096/real_time   1.841µ ±  3%   1.827µ ± 1%       ~ (p=0.394 n=6)
bench/sigmoid_fp64_1x16_x86_avx512f_avx512bw/m:1/n:4096/real_time   4.549µ ±  7%   4.619µ ± 3%       ~ (p=0.093 n=6)
bench/exp_fp32_1x32_x86_avx2_fma3/m:1/n:4096/real_time              2.481µ ±  3%   2.575µ ± 3%  +3.77% (p=0.002 n=6)
bench/exp_fp64_1x16_x86_avx2_fma3/m:1/n:4096/real_time              6.648µ ±  3%   6.834µ ± 3%  +2.80% (p=0.026 n=6)
bench/exp_fp32_1x16_x86_avx2/m:1/n:4096/real_time                   3.777µ ±  2%   3.876µ ± 3%  +2.60% (p=0.015 n=6)
bench/exp_fp64_1x16_x86_avx2/m:1/n:4096/real_time                   9.753µ ±  3%   9.919µ ± 3%       ~ (p=0.180 n=6)
bench/sigmoid_fp32_1x16_x86_avx2/m:1/n:4096/real_time               4.010µ ±  3%   3.910µ ± 2%       ~ (p=0.093 n=6)
bench/sigmoid_fp64_1x8_x86_avx2/m:1/n:4096/real_time                10.65µ ±  3%   10.66µ ± 2%       ~ (p=0.589 n=6)
bench/exp_fp32_1x16_x86_sse2/m:1/n:4096/real_time                   5.544µ ±  1%   5.770µ ± 1%  +4.07% (p=0.002 n=6)
bench/exp_fp64_1x8_x86_sse2/m:1/n:4096/real_time                    16.14µ ±  7%   16.53µ ± 3%       ~ (p=0.310 n=6)
bench/sigmoid_fp32_1x32_x86_sse2/m:1/n:4096/real_time               5.792µ ±  4%   5.917µ ± 4%       ~ (p=0.394 n=6)
bench/sigmoid_fp64_1x8_x86_sse2/m:1/n:4096/real_time                17.36µ ±  3%   17.72µ ± 2%       ~ (p=0.132 n=6)
geomean                                                             9.036µ         9.144µ       +1.19%
```

PiperOrigin-RevId: 918101088
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant