Skip to content

Commit 4f386fa

Browse files
committed
Update on "[Executorch] Add non-flash SDPA for decode"
Add a cpu_sdpa template function in op_sdpa_impl.h that provides a simpler SDPA implementation using standard GEMM (no tiling). This is useful as a baseline and for cases where flash attention is not optimal. The implementation uses a single SeqDim parameter for all tensors and supports causal masking, attention masks, GQA, and multi-threading. During decode (seq_len == 1), the tiled flash attention implementation has unnecessary overhead from its blocking/tiling logic. The simpler unfused SDPA path using direct GEMM is more efficient for single-query attention, yielding a ~25-30% decode throughput improvement on S25 (41 -> 53 tok/s for a 1.4B-parameter model). This change makes cpu_sdpa always available (previously it was gated behind ET_USE_UNFUSED_SDPA) and dispatches to it when seq_len == 1 and the inputs are not quantized. Prefill continues to use flash attention. Differential Revision: [D96044318](https://our.internmc.facebook.com/intern/diff/D96044318/) [ghstack-poisoned]
2 parents 611b5fa + c762716 commit 4f386fa

31 files changed

Lines changed: 2760 additions & 46 deletions

.ci/scripts/test_model_e2e.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ EOF
354354
fi
355355
;;
356356
qwen3_5_moe)
357-
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 32"
357+
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 128 --temperature 0"
358358
;;
359359
voxtral_realtime)
360360
RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --preprocessor_path ${MODEL_DIR}/$PREPROCESSOR --audio_path ${MODEL_DIR}/$AUDIO_FILE --temperature 0"
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
name: Test Cadence
2+
3+
permissions:
4+
id-token: write
5+
contents: read
6+
7+
on:
8+
workflow_call:
9+
inputs:
10+
docker-image:
11+
description: 'Docker image to use'
12+
required: false
13+
type: string
14+
default: ci-image:executorch-ubuntu-22.04-clang12
15+
runner:
16+
description: 'Runner type'
17+
required: false
18+
type: string
19+
default: linux.8xlarge.memory
20+
ref:
21+
description: 'Git ref to checkout'
22+
required: false
23+
type: string
24+
default: ${{ github.sha }}
25+
timeout:
26+
description: 'Job timeout in minutes'
27+
required: false
28+
type: number
29+
default: 90
30+
31+
jobs:
32+
test-aot:
33+
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
34+
with:
35+
job-name: test-aot
36+
runner: ${{ inputs.runner }}
37+
docker-image: ${{ inputs.docker-image }}
38+
submodules: recursive
39+
ref: ${{ inputs.ref }}
40+
timeout: ${{ inputs.timeout }}
41+
script: |
42+
set -eux
43+
conda create -y -n cadence_test python=3.12 > /dev/null
44+
conda activate cadence_test
45+
46+
./install_requirements.sh > /dev/null
47+
pip install -e . --no-build-isolation > /dev/null
48+
pip install beartype later pyre_extensions pytest-xdist
49+
50+
python -m pytest backends/cadence/aot/tests/ -v -n auto
51+
52+
test-ops:
53+
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
54+
with:
55+
job-name: test-ops
56+
runner: ${{ inputs.runner }}
57+
docker-image: ${{ inputs.docker-image }}
58+
submodules: recursive
59+
ref: ${{ inputs.ref }}
60+
timeout: ${{ inputs.timeout }}
61+
download-artifact: cadence-runner-build
62+
script: |
63+
set -eux
64+
conda create -y -n cadence_test python=3.12 > /dev/null
65+
conda activate cadence_test
66+
67+
./install_requirements.sh > /dev/null
68+
pip install -e . --no-build-isolation > /dev/null
69+
pip install beartype later pyre_extensions pytest-xdist
70+
71+
# Use the pre-built runner from the build job
72+
mkdir -p cmake-out/backends/cadence
73+
cp "${RUNNER_ARTIFACT_DIR}/cadence_runner" cmake-out/backends/cadence/cadence_runner
74+
chmod +x cmake-out/backends/cadence/cadence_runner
75+
76+
export PYTHONPATH="${PYTHONPATH:-}:$(pwd)/backends/cadence/utils/FACTO"
77+
python -m pytest examples/cadence/operators/ -v -n auto

.github/workflows/build-cadence-runner.yml

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Build Cadence
1+
name: Cadence Build & Test
22

33
on:
44
pull_request:
@@ -13,7 +13,7 @@ concurrency:
1313
cancel-in-progress: true
1414

1515
jobs:
16-
cpu:
16+
cpu-build:
1717
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
1818
permissions:
1919
id-token: write
@@ -25,6 +25,7 @@ jobs:
2525
submodules: recursive
2626
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
2727
timeout: 90
28+
upload-artifact: cadence-runner-build
2829
script: |
2930
set -eux
3031
# The generic Linux job chooses to use base env, not the one setup by the image
@@ -33,3 +34,15 @@ jobs:
3334
3435
./install_requirements.sh > /dev/null
3536
bash backends/cadence/build_cadence_runner.sh
37+
38+
# Copy runner binary to artifact dir for downstream test jobs
39+
cp cmake-out/backends/cadence/cadence_runner "${RUNNER_ARTIFACT_DIR}/"
40+
41+
cpu-test:
42+
needs: cpu-build
43+
permissions:
44+
id-token: write
45+
contents: read
46+
uses: ./.github/workflows/_test_cadence.yml
47+
with:
48+
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}

.github/workflows/cuda.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,8 @@ jobs:
145145
# Run CUDA backend Python tests
146146
python -m pytest backends/cuda/tests backends/cuda/passes/tests -v -o "addopts="
147147
148-
# Run quantize roundtrip tests (Qwen 3.5 MoE save/load prequantized)
149-
python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py -v -o "addopts="
148+
# Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache)
149+
python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py -v -o "addopts="
150150
151151
export-model-cuda-artifact:
152152
name: export-model-cuda-artifact

backends/aoti/common_shims_slim.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,10 @@ int32_t aoti_torch_dtype_int8() {
134134
return 1; // ScalarType::Char
135135
}
136136

137+
int32_t aoti_torch_dtype_uint8() {
138+
return 0; // ScalarType::Byte
139+
}
140+
137141
int32_t aoti_torch_dtype_bool() {
138142
return 11; // ScalarType::Bool
139143
}

backends/aoti/common_shims_slim.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int64();
7676
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int32();
7777
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int16();
7878
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int8();
79+
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_uint8();
7980
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bool();
8081

8182
// ============================================================

backends/aoti/slim/c10/core/ScalarType.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using BFloat16 = ::executorch::runtime::etensor::BFloat16;
2323
/// Enum representing the scalar type (dtype) of tensor elements.
2424
/// Note: Enum values must match PyTorch's c10::ScalarType for compatibility.
2525
enum class ScalarType : int8_t {
26-
// Byte = 0, // uint8_t - not currently needed
26+
Byte = 0, // uint8_t
2727
Char = 1, // int8_t
2828
Short = 2, // int16_t
2929
Int = 3, // int32_t
@@ -43,6 +43,7 @@ enum class ScalarType : int8_t {
4343
};
4444

4545
// Type alias constants for convenience
46+
constexpr ScalarType kByte = ScalarType::Byte;
4647
constexpr ScalarType kChar = ScalarType::Char;
4748
constexpr ScalarType kShort = ScalarType::Short;
4849
constexpr ScalarType kInt = ScalarType::Int;
@@ -56,6 +57,8 @@ constexpr ScalarType kBFloat16 = ScalarType::BFloat16;
5657
/// @return The size in bytes of a single element.
5758
inline size_t elementSize(ScalarType t) {
5859
switch (t) {
60+
case ScalarType::Byte:
61+
return sizeof(uint8_t);
5962
case ScalarType::Char:
6063
return sizeof(int8_t);
6164
case ScalarType::Short:
@@ -80,6 +83,8 @@ inline size_t elementSize(ScalarType t) {
8083
/// @return The name of the scalar type.
8184
inline const char* toString(ScalarType t) {
8285
switch (t) {
86+
case ScalarType::Byte:
87+
return "Byte";
8388
case ScalarType::Char:
8489
return "Char";
8590
case ScalarType::Short:
@@ -114,6 +119,7 @@ inline bool isFloatingType(ScalarType t) {
114119
/// @return true if the scalar type is integral, false otherwise.
115120
inline bool isIntegralType(ScalarType t, bool includeBool) {
116121
switch (t) {
122+
case ScalarType::Byte:
117123
case ScalarType::Char:
118124
case ScalarType::Short:
119125
case ScalarType::Int:
@@ -138,6 +144,7 @@ inline bool isBoolType(ScalarType t) {
138144
/// @return true if the scalar type is valid, false otherwise.
139145
inline bool isValidScalarType(ScalarType t) {
140146
switch (t) {
147+
case ScalarType::Byte:
141148
case ScalarType::Char:
142149
case ScalarType::Short:
143150
case ScalarType::Int:

backends/arm/runtime/VGFBackend.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,14 @@ void vkml_free_basics(
8888

8989
class VGFBackend final : public ::executorch::runtime::BackendInterface {
9090
public:
91-
VGFBackend() {
91+
VGFBackend() = default;
92+
93+
// Lazy Vulkan init — runs on first use, not in the constructor.
94+
void ensure_initialized() {
95+
if (is_initialized_) {
96+
return;
97+
}
98+
9299
VkResult result;
93100

94101
// Fetch basic vulkan objects once
@@ -122,6 +129,7 @@ class VGFBackend final : public ::executorch::runtime::BackendInterface {
122129

123130
bool is_available() const override {
124131
ET_LOG(Info, "Checking VGFBackend is available");
132+
const_cast<VGFBackend*>(this)->ensure_initialized();
125133
if (!is_initialized_) {
126134
return false;
127135
}
@@ -134,6 +142,7 @@ class VGFBackend final : public ::executorch::runtime::BackendInterface {
134142
ArrayRef<CompileSpec> compile_specs) const override {
135143
ET_LOG(Info, "Entered VGF init");
136144

145+
const_cast<VGFBackend*>(this)->ensure_initialized();
137146
if (!is_initialized_) {
138147
ET_LOG(
139148
Error,
@@ -334,6 +343,16 @@ VkResult vkml_allocate_basics(
334343
}
335344
volkLoadInstance(*instance);
336345

346+
// Bail out if the driver lacks ARM tensor/datagraph extensions.
347+
if (!vkCreateTensorARM) {
348+
ET_LOG(
349+
Error,
350+
"Vulkan driver does not support ARM tensor extensions (VK_ARM_tensors)");
351+
vkDestroyInstance(*instance, nullptr);
352+
*instance = VK_NULL_HANDLE;
353+
return VK_ERROR_FEATURE_NOT_PRESENT;
354+
}
355+
337356
// Pick first GPU
338357
uint32_t gpu_count = 0;
339358
vkEnumeratePhysicalDevices(*instance, &gpu_count, nullptr);

backends/arm/runtime/VGFSetup.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef<CompileSpec> specs) {
538538
.pNext = nullptr,
539539
.flags = 0,
540540
.bindingCount = static_cast<uint32_t>(layout_bindings.size()),
541-
layout_bindings.data(),
541+
.pBindings = layout_bindings.data(),
542542
};
543543
result =
544544
vkCreateDescriptorSetLayout(vk_device, &layout_info, nullptr, &vk_layout);

backends/arm/runtime/targets.bzl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,35 @@ def define_common_targets():
3232
"fbsource//third-party/ethos-u-core-driver:core_driver",
3333
],
3434
)
35+
runtime.cxx_library(
36+
name = "vgf_backend",
37+
srcs = [
38+
"VGFBackend.cpp",
39+
"VGFSetup.cpp",
40+
# Volk must be compiled directly into this target so its global
41+
# function-pointer variables live in the same linkage unit.
42+
# Linking from a separate static library causes the linker to
43+
# drop the symbols when building a shared library.
44+
"fbsource//third-party/vulkan-headers-1.4.343/v1.4.343/src:volk_arm_src",
45+
],
46+
exported_headers = ["VGFSetup.h"],
47+
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
48+
link_whole = True,
49+
supports_python_dlopen = True,
50+
compiler_flags = [
51+
"-Wno-global-constructors",
52+
"-Wno-header-hygiene",
53+
"-Wno-unused-variable",
54+
"-Wno-missing-field-initializers",
55+
"-DUSE_VULKAN_WRAPPER",
56+
"-DUSE_VULKAN_VOLK",
57+
],
58+
visibility = ["PUBLIC"],
59+
deps = [
60+
"//executorch/runtime/backend:interface",
61+
"//executorch/runtime/core:core",
62+
"fbsource//third-party/arm-vgf-library/v0.8.0/src:vgf",
63+
"fbsource//third-party/vulkan-headers-1.4.343/v1.4.343/src:volk_arm",
64+
"fbsource//third-party/vulkan-headers-1.4.343/v1.4.343/src:vulkan-headers",
65+
],
66+
)

0 commit comments

Comments
 (0)