[ET-VK] Fix softmax NaN and depthwise conv correctness bugs#17848

Merged
meta-codesync[bot] merged 1 commit into gh/SS-JIA/457/base from gh/SS-JIA/457/head on Mar 4, 2026

Conversation

@SS-JIA
Contributor

@SS-JIA SS-JIA commented Mar 4, 2026

Stack from ghstack (oldest at bottom):

Fix three bugs causing incorrect output when running the edgeTAM model
with the Vulkan backend. Together these fixes bring the model from
producing all-NaN output to matching the reference within fp32 tolerance.

**Bug 1 — softmax_packed_dim OOB max contamination (softmax.glsl)**

In `softmax_packed_dim`, each workgroup uses NWORKERS=4 threads to
collaboratively reduce along the packed dimension. Before the main loop,
each worker initializes `max_elements` by loading from texel index
`tid.x`. When NWORKERS exceeds the number of texels (e.g., a 12-element
dim has only 3 texels, but worker 3 tries to load texel index 3), the
load is out-of-bounds and returns 0 under Vulkan's robust image access
behavior. This 0 enters the cross-worker max reduction, so for any row
where all actual values are negative, the computed max becomes 0 instead
of the true (negative) max. Then `exp(value - 0)` underflows to 0 for
all elements, giving denominator=0 and NaN output.

Fixed by initializing `max_elements = vec4(-3.402823e+38)` (i.e.,
-FLT_MAX) so that workers with no valid texels contribute the smallest
finite float, which can never win the max reduction. Also added a
`safe_denominator = max(denominator, 1e-37)` clamp as a secondary safety
net against any remaining underflow edge cases.

This affected the edgeTAM attention softmax over 12 key positions, where
~15% of query rows had all-negative attention scores and produced NaN.
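
The contaminated reduction can be reproduced on the host. Below is a
minimal Python model of the per-row max reduction (simplified: each
worker seeds from a single texel, and the shader's 0/0 → NaN is made
explicit); `oob_seed` stands in for what the out-of-bounds load returns:

```python
import math

def softmax_row(values, num_workers=4, texel_size=4, oob_seed=0.0):
    """Model the per-row reduction in softmax_packed_dim.

    Each worker seeds its running max from texel index tid.x; a worker
    whose index is past the last texel reads out-of-bounds, which
    yields 0.0 (the bug) or -FLT_MAX (the fix) via oob_seed.
    """
    num_texels = math.ceil(len(values) / texel_size)
    texels = [values[i * texel_size:(i + 1) * texel_size]
              for i in range(num_texels)]
    worker_max = [max(texels[tid]) if tid < num_texels else oob_seed
                  for tid in range(num_workers)]
    row_max = max(worker_max)            # cross-worker max reduction
    exps = [math.exp(v - row_max) for v in values]
    denom = sum(exps)
    if denom == 0.0:                     # the shader's 0/0 -> NaN
        return [float("nan")] * len(values)
    return [e / denom for e in exps]

row = [-800.0] * 12          # all-negative scores: 3 texels, 4 workers
buggy = softmax_row(row, oob_seed=0.0)            # max contaminated to 0
fixed = softmax_row(row, oob_seed=-3.402823e+38)  # -FLT_MAX seed
```

With the 0 seed, `row_max` is 0 and every `exp(-800)` underflows, so the
whole row becomes NaN; with the -FLT_MAX seed, `row_max` is the true
maximum and the row sums to 1 as expected.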

**Bug 1b — softmax_nonpacked_dim defensive hardening (softmax.glsl)**

Applied similar defensive fixes to `softmax_nonpacked_dim`:
- Clamped the denominator via `max(denominators, vec4(1e-37))` to
  prevent 0/0 = NaN if all exp values underflow.
- Added IEEE 754 bit-level NaN/Inf → 0 sanitization on output texels.
  This uses `floatBitsToUint`/`uintBitsToFloat` with exponent-bit
  masking rather than `isnan()` or `x != x`, which may not work reliably
  on all GPU drivers due to OpIsNan bugs and ordered-comparison
  semantics.
- Added `memoryBarrierImage()` after the output write loop to flush
  imageStore writes so they're visible to subsequent GPU operations.
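
The bit-level check relies on a property of IEEE 754 binary32: a value
is NaN or Inf exactly when all eight exponent bits are set. A host-side
Python sketch of the same test (the actual shader applies it per texel
component with `floatBitsToUint`/`uintBitsToFloat`):

```python
import struct

def sanitize_fp32(x: float) -> float:
    """Map NaN/Inf to 0.0 by inspecting the binary32 exponent bits,
    avoiding isnan() / x != x, which some drivers miscompile."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    if (bits & 0x7F800000) == 0x7F800000:  # all exponent bits set
        return 0.0                         # NaN (mantissa != 0) or Inf
    return x
```

Any finite value, including subnormals, has at least one exponent bit
clear, so only NaN and ±Inf are rewritten.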

**Bug 2 — conv2d_dw parameter binding mismatch (Convolution.cpp)**

The depthwise convolution code path in `add_conv2d_node` unconditionally
passed kernel parameters (stride, padding, dilation, etc.) via push
constants. However, the base `conv2d_dw.glsl` shader (used for non-3x3
and non-5x5 kernels, such as 1x1 depthwise convolutions) declares these
parameters as UBOs at binding points 4–8, not as push constants. The
`_output_tile` shader variants do use push constants, so 3x3 and 5x5
depthwise convolutions worked correctly.

For 1x1 depthwise convolutions, the shader read from unbound UBOs,
getting zeros for stride, padding, dilation, and overlay_region. With
stride=0 and overlay_region=(0,0), the convolution loop never executed,
producing output equal to just the bias (effectively zero for small
biases).

Fixed by checking whether the selected shader name contains
`_output_tile`. If not, parameters are passed via UBOs (matching the
shader's declarations) instead of push constants.
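
The dispatch-side check reduces to a substring test on the selected
shader name. A hypothetical sketch (`uses_push_constants` and the
variant name below are illustrative, not the actual identifiers in
`Convolution.cpp`):

```python
def uses_push_constants(shader_name: str) -> bool:
    """Only the _output_tile variants declare kernel params as push
    constants; the base conv2d_dw shader expects UBOs at bindings 4-8."""
    return "_output_tile" in shader_name
```

Keying the decision off the chosen shader name keeps the binding scheme
and the shader selection logic from drifting apart.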

**Bug 3 — conv2d_dw workgroup size mismatch (Convolution.cpp)**

The base `conv2d_dw.glsl` shader uses a fully 1D thread mapping where
`gl_GlobalInvocationID.x` encodes all three output dimensions:
`pos.x = gid.x % W`, `pos.y = (gid.x / W) % H`,
`pos.z = gid.x / (W * H)`. The `_output_tile` variants use a 2D mapping
with spatial tiles in `.x` and channels in `.y`.

The `conv2d_global_wg_size` callback was dispatching all depthwise
shaders with workgroup size `{W*H, C_packed, 1}`, which is correct for
`_output_tile` but wrong for the base shader. With this size, all
threads have `gid.x < W*H`, so `pos.z = gid.x / (W*H) = 0` — only
channel texel 0 (channels 0–3 out of e.g. 192) gets computed.

Fixed by dispatching `{W*H*C_packed, 1, 1}` for the base shader so
that `gid.x` ranges over all spatial × channel positions.
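
The coverage difference is easy to see by replaying the base shader's 1D
decode over each dispatch shape. A Python sketch (the `W`, `H`,
`C_packed` values are illustrative):

```python
def decode_1d(gid_x: int, W: int, H: int):
    """The base conv2d_dw shader's decode of gl_GlobalInvocationID.x."""
    return (gid_x % W, (gid_x // W) % H, gid_x // (W * H))

W, H, C_packed = 4, 4, 48  # e.g. 192 channels packed into 48 texels

# Buggy dispatch {W*H, C_packed, 1}: the base shader only reads gid.x,
# so every thread decodes to channel texel pos.z == 0.
buggy_z = {decode_1d(x, W, H)[2] for x in range(W * H)}

# Fixed dispatch {W*H*C_packed, 1, 1}: gid.x sweeps every (x, y, z).
fixed_positions = {decode_1d(x, W, H) for x in range(W * H * C_packed)}
```

Under the old shape only `pos.z == 0` is ever produced; under the fixed
shape the decode is a bijection onto all `W*H*C_packed` output
positions.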

Differential Revision: [D95217947](https://our.internmc.facebook.com/intern/diff/D95217947/)

@pytorch-bot

pytorch-bot Bot commented Mar 4, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17848

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 4 Unrelated Failures

As of commit b2e1541 with merge base 1a75394:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla Bot added the **CLA Signed** label Mar 4, 2026
@github-actions

github-actions Bot commented Mar 4, 2026

This PR needs a `release notes:` label

If your change should be included in the release notes (i.e., would users of this library care about this change?), please use a label starting with `release notes:`. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example:
`@pytorchbot label "release notes: none"`

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@meta-codesync meta-codesync Bot merged commit bd536d3 into gh/SS-JIA/457/base Mar 4, 2026
208 of 224 checks passed
@meta-codesync meta-codesync Bot deleted the gh/SS-JIA/457/head branch March 4, 2026 23:43
@meta-codesync meta-codesync Bot temporarily deployed to cherry-pick-bot March 4, 2026 23:43 Inactive
SS-JIA pushed a commit that referenced this pull request Mar 5, 2026
jpiat pushed a commit to jpiat/executorch that referenced this pull request Mar 17, 2026

Labels

CLA Signed, fb-exported, meta-exported
