
Manual Sync ET-VK Diffs #19167

Merged
SS-JIA merged 3 commits into main from ssj_manual_sync
Apr 27, 2026

Conversation


@SS-JIA SS-JIA commented Apr 27, 2026

Manually sync the PR stack #19115 (landed internally in phabricator)

ssjia added 3 commits April 27, 2026 13:33
…itions

Pull Request resolved: #19113

Two changes to the layout assignment pass that together reduce layout transitions by ~89% for transformer-style models (73 → 9 for the EdgeTAM ViT-S encoder):

1. BFS instead of DFS for downstream user tracing. The old DFS could exhaust the search budget (64 nodes) on one deep branch before discovering a constraining op on a sibling branch. BFS explores all immediate users at each level first, finding nearby layout-constrained ops (e.g. linear requiring width_packed) more reliably.

2. Prefer downstream consumers' layout over upstream source's layout. Previously, if the upstream source already had a representation (e.g. channels_packed from conv2d), that was applied first and locked in the layout via sync_primary_io_repr before downstream tracing could run. Now, downstream users are traced first to discover what layout they prefer, and the upstream source is only used as a fallback when downstream doesn't constrain.

For ViT-style transformers, conv2d (patch embedding) forces channels_packed, which previously propagated through all residual connections via flexible ops (layer_norm, add, mul). With downstream-preferred layout, linear ops' width_packed requirement is discovered first, so the entire transformer stack stays width_packed. Transitions only occur at the conv2d↔transformer boundaries.
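The BFS-over-DFS rationale above can be illustrated with a small sketch (hypothetical graph and helper names; the 64-node budget is from the description, but this is not the actual pass code):

```python
from collections import deque

def find_constrained_user_bfs(start, users, is_constrained, budget=64):
    """Breadth-first trace of downstream users: visits all immediate
    users before descending, so a nearby layout-constrained op on a
    sibling branch is found even when another branch is very deep."""
    queue = deque(users(start))
    seen = set()
    while queue and budget > 0:
        node = queue.popleft()
        if node in seen:
            continue
        seen.add(node)
        budget -= 1
        if is_constrained(node):
            return node
        queue.extend(users(node))
    return None

# Toy graph: 'src' fans out into a 100-node deep chain and a direct 'linear'.
graph = {"src": ["d0", "linear"], "linear": []}
for i in range(100):
    graph[f"d{i}"] = [f"d{i+1}"]
graph["d100"] = []

found = find_constrained_user_bfs("src", lambda n: graph[n],
                                  lambda n: n == "linear")
# BFS reaches 'linear' on its second visit; a DFS that descends into the
# deep chain first would burn the whole 64-node budget before seeing it.
```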
ghstack-source-id: 373258238
@exported-using-ghexport

Differential Revision: [D102360203](https://our.internmc.facebook.com/intern/diff/D102360203/)
Pull Request resolved: #19114

This diff extends the ET-VK fused SDPA operator so it can be used for the ViT attention blocks in the EdgeTAM ViT-S encoder. The main correctness problem is that Q@K^T dot products in ViT attention can exceed the fp16 max (65504), so fp32 accumulation is required.

**fp16 overflow fix**: The intermediate `attn_weights` buffer is now always fp32 regardless of input dtype. Previously the QK shader accumulated in fp32 but stored to an fp16 buffer, causing overflow. The softmax shader reads fp32 attention weights and writes fp16 softmax output (safe since values are in [0, 1]).

**Texture support**: The QK and AV shaders support both buffer and texture3d storage for Q/K/V/output. The intermediate `attn_weights` and `attn_weights_softmax` tensors now inherit the storage type of the input/output (q_projected for the LLM path, out for the fused path), so the entire fused SDPA pipeline runs in a uniform storage type and no SDPA-internal layout transitions are needed.
ghstack-source-id: 373258239
@exported-using-ghexport

Differential Revision: [D102360200](https://our.internmc.facebook.com/intern/diff/D102360200/)
Pull Request resolved: #19115

Introduces `et_vk.apply_rotary_emb_interleaved`, a fused Vulkan custom operator for the "complex-number" RoPE variant used by SAM2/EdgeTAM's memory attention. This replaces a layout-shuffle chain of 12+ ops (`view/unbind/stack/view` -> lowers to `slice_copy + squeeze_copy + unsqueeze_copy + cat + view_copy`) with a single GPU dispatch.

**Math**: On pair-interleaved inputs where element `2k` is real and `2k+1` is imag, for each `k in [0, C/2)`:

  out[2k]   = x[2k] * cos[k] - x[2k+1] * sin[k]
  out[2k+1] = x[2k] * sin[k] + x[2k+1] * cos[k]
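A NumPy reference for the pair-interleaved math above (a sketch of the op's semantics under the contract described below, not the shader), cross-checked against complex multiplication:

```python
import numpy as np

def apply_rotary_emb_interleaved_ref(x, freqs_cis):
    """x: [B, N, C] with real/imag pair-interleaved on C.
    freqs_cis: any shape with N*C elements, cos/sin interleaved on the
    innermost dim; reshaped here to [N, C//2, 2]."""
    B, N, C = x.shape
    fc = freqs_cis.reshape(N, C // 2, 2)
    cos, sin = fc[..., 0], fc[..., 1]    # each [N, C//2]
    xr, xi = x[..., 0::2], x[..., 1::2]  # each [B, N, C//2]
    out = np.empty_like(x)
    out[..., 0::2] = xr * cos - xi * sin
    out[..., 1::2] = xr * sin + xi * cos
    return out

# Cross-check: rotating (xr + i*xi) by e^{i*theta} gives the same result.
rng = np.random.default_rng(0)
x = rng.standard_normal((1, 4, 8)).astype(np.float32)
theta = rng.standard_normal((4, 4)).astype(np.float32)
freqs = np.stack([np.cos(theta), np.sin(theta)], axis=-1)  # [N, C/2, 2]

out = apply_rotary_emb_interleaved_ref(x, freqs)
z = (x[..., 0::2] + 1j * x[..., 1::2]) * (np.cos(theta) + 1j * np.sin(theta))
assert np.allclose(out[..., 0::2], z.real, atol=1e-5)
assert np.allclose(out[..., 1::2], z.imag, atol=1e-5)
```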

**Why a new op instead of reusing `et_vk.apply_rotary_emb`**: The existing LLM-oriented operator takes `(xq, xk)` pairs with separate `freqs_cos` / `freqs_sin` tensors and 4D `(B, S, H, D)` shapes optimized for LLM-prefill two-texel-per-thread reuse. SAM2's memory attention passes a single 3D `(B, N, C)` tensor through RoPE (no heads dim) with a fused `[N, C/2, 2]` freqs tensor. Reusing the existing op would force runtime splits of the fused freqs and a separate double dispatch for Q and K, defeating the fusion. A sibling shader is tighter for both workloads.

**Op contract**: `apply_rotary_emb_interleaved(x, freqs_cis) -> Tensor` where `x` is `[B, N, C]` and `freqs_cis` is any rank with `N*C` elements and the `cos`/`sin` values interleaved on the innermost dim. In EdgeTAM's memory attention the native shape is `[1, N, C/2, 2]`; passing it through without a reshape keeps the exported graph clean of bracketing view_copy dispatches.

**Shader**: Single-dispatch kernel, one texel out per thread. Each thread reads one `x` texel (2 real/imag pairs) and the corresponding `freqs_cis` entries (2 cos/sin pairs) flat-indexed from buffer storage, writes one output texel. `x` and output support buffer + texture3d; `freqs_cis` is always buffer-storage (small tensor, flat indexing is simplest). Supports fp16 and fp32 via the `FP_T` dtype iterator in the YAML.

**Op registration**: `Meta` kernel returns `torch.empty_like(x)` to keep the op opaque during `torch.export`. `CPU` kernel holds the reference math so non-Vulkan backends keep working. `op_registry.py` pins `freqs_cis` storage to `CONTIGUOUS_BUFFER` while leaving `x` at `CONTIGUOUS_ANY`.
ghstack-source-id: 373258231
@exported-using-ghexport

Differential Revision: [D102360202](https://our.internmc.facebook.com/intern/diff/D102360202/)
@pytorch-bot

pytorch-bot Bot commented Apr 27, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19167

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Apr 27, 2026
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@SS-JIA SS-JIA merged commit f61350b into main Apr 27, 2026
155 of 169 checks passed
@SS-JIA SS-JIA deleted the ssj_manual_sync branch April 27, 2026 20:38