Skip to content

Commit 14168a3

Browse files
[ExecuTorch][WebGPU] Add update_cache tests (native numeric + export)
Pull Request resolved: pytorch#20084 Tests for `llama.update_cache.default`, stacked on the op diff below. `test/ops/sdpa/test_update_cache.py` lowers the op through `VulkanPartitioner` (asserting it delegates to VulkanBackend) and exports per-case `.pte`s; `test/native/test_update_cache.cpp` runs them on-GPU and checks an integer-exact scatter golden against the returned cache. Coverage mirrors the Vulkan KV-cache test (`VulkanSDPATest`): single-shot writes at varied shapes/offsets, plus a multi-step advancing-input_pos replay that threads the returned cache across steps over the same GQA param sets (incl. llama3 head_dim=128). Comparing the cache directly is stronger than Vulkan, which checks it only indirectly via the SDPA output. Authored with assistance from Claude. ghstack-source-id: 391979582 @exported-using-ghexport Differential Revision: [D107547307](https://our.internmc.facebook.com/intern/diff/D107547307/)
1 parent 1208a2f commit 14168a3

5 files changed

Lines changed: 514 additions & 3 deletions

File tree

backends/webgpu/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,4 +125,7 @@ if(EXECUTORCH_BUILD_WEBGPU_TEST)
125125
add_webgpu_native_test(
126126
webgpu_scratch_buffer_test test/native/test_scratch_buffer.cpp
127127
)
128+
add_webgpu_native_test(
129+
webgpu_update_cache_test test/native/test_update_cache.cpp
130+
)
128131
endif()

backends/webgpu/scripts/test_webgpu_native_ci.sh

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
#
1919
# Builds whatever native test targets are present in the landed tree (NOT a fixed
2020
# list). This stack lands: webgpu_native_test, webgpu_rms_norm_test (base) +
21-
# webgpu_dispatch_order_test, webgpu_scratch_buffer_test (D107576199). update_cache
22-
# / SDPA executables join automatically once their sibling diffs land.
21+
# webgpu_dispatch_order_test, webgpu_scratch_buffer_test (D107576199) +
22+
# webgpu_update_cache_test (D107547307). SDPA executables join once they land.
2323

2424
set -e
2525

@@ -45,6 +45,8 @@ RMS_NORM_DIR="/tmp/rmsn"
4545
RMS_NORM_OK=1
4646
DISPATCH_ORDER_DIR="/tmp/dispatch_order"
4747
DISPATCH_ORDER_OK=1
48+
UPDATE_CACHE_DIR="/tmp/update_cache"
49+
UPDATE_CACHE_OK=1
4850

4951
$PYTHON_EXECUTABLE -c "
5052
from executorch.backends.webgpu.test.ops.add.test_add import export_add_model, export_chained_add_model
@@ -62,6 +64,17 @@ from executorch.backends.webgpu.test.ops.dispatch_order.test_dispatch_order impo
6264
export_dispatch_order_cases('${DISPATCH_ORDER_DIR}')
6365
" || { echo "WARN: dispatch_order export failed; skipping dispatch_order native test"; DISPATCH_ORDER_OK=0; }
6466

67+
$PYTHON_EXECUTABLE -c "
68+
from executorch.backends.webgpu.test.ops.sdpa.test_update_cache import (
69+
export_update_cache_cases,
70+
export_update_cache_replay,
71+
export_update_cache_negative,
72+
)
73+
export_update_cache_cases('${UPDATE_CACHE_DIR}')
74+
export_update_cache_replay('${UPDATE_CACHE_DIR}')
75+
export_update_cache_negative('${UPDATE_CACHE_DIR}')
76+
" || { echo "WARN: update_cache export failed; skipping update_cache native test"; UPDATE_CACHE_OK=0; }
77+
6578
# ── Configure (Dawn-only: no -DWEBGPU_IMPL; Dawn is the sole backend) ─────────
6679
echo "=== Configure WebGPU native tests on Dawn ==="
6780
rm -rf "${BUILD_DIR}"
@@ -79,7 +92,7 @@ cmake \
7992
"${EXECUTORCH_ROOT}"
8093

8194
# ── Build + run every native test target that exists in this tree ────────────
82-
TARGETS=(webgpu_native_test webgpu_rms_norm_test webgpu_dispatch_order_test webgpu_scratch_buffer_test)
95+
TARGETS=(webgpu_native_test webgpu_rms_norm_test webgpu_dispatch_order_test webgpu_scratch_buffer_test webgpu_update_cache_test)
8396
BIN_DIR="${BUILD_DIR}/backends/webgpu"
8497

8598
# Which targets are defined depends on which diffs are landed (native_test +
@@ -122,6 +135,9 @@ fi
122135
if [[ "${RMS_NORM_OK}" == "1" && -x "${BIN_DIR}/webgpu_rms_norm_test" ]]; then
123136
"${BIN_DIR}/webgpu_rms_norm_test" "${RMS_NORM_DIR}"
124137
fi
138+
if [[ "${UPDATE_CACHE_OK}" == "1" && -x "${BIN_DIR}/webgpu_update_cache_test" ]]; then
139+
"${BIN_DIR}/webgpu_update_cache_test" "${UPDATE_CACHE_DIR}"
140+
fi
125141
if [[ "${DISPATCH_ORDER_OK}" == "1" && -x "${BIN_DIR}/webgpu_dispatch_order_test" ]]; then
126142
"${BIN_DIR}/webgpu_dispatch_order_test" "${DISPATCH_ORDER_DIR}"
127143
fi
Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/webgpu/runtime/WebGPUDevice.h>
10+
#include <executorch/extension/module/module.h>
11+
#include <executorch/extension/tensor/tensor.h>
12+
13+
#include <algorithm>
14+
#include <cmath>
15+
#include <cstdio>
16+
#include <cstdlib>
17+
#include <string>
18+
#include <vector>
19+
20+
using namespace executorch::backends::webgpu;
21+
using namespace executorch::extension;
22+
using namespace executorch::runtime;
23+
24+
namespace {
25+
26+
// Parameters for one single-shot update_cache scenario. run_case feeds the
// model a value tensor of shape {1, s, h, d} and a cache tensor of shape
// {1, cmax, h, d}. Field order matters: kCases brace-initialises positionally.
struct UpdateCacheCase {
27+
// Case label; also the stem of the "<name>.pte" file loaded from the case dir.
const char* name;
28+
// Dim 1 of the value tensor (printed as "S"; presumably the number of new
// sequence positions being written).
int s;
29+
// Dim 2 of both tensors (printed as "H"; presumably the head count).
int h;
30+
// Dim 3 of both tensors (printed as "D"; presumably the per-head dimension).
int d;
31+
// Dim 1 of the cache tensor (printed as "Cmax").
int cmax;
32+
// Cache row at which the write starts; the reference scatter begins at flat
// offset input_pos * h * d.
int input_pos;
33+
};
34+
35+
// Single-shot cases; run_case loads "<name>.pte" from the case directory for
// each row. Columns are {name, s, h, d, cmax, input_pos}. Mirrors
// test_update_cache.py CASES; the expected cache is not stored here but
// recomputed inline by run_case as an exact host-side scatter copy.
36+
constexpr UpdateCacheCase kCases[] = {
37+
{"prefill", 2, 2, 4, 8, 0},
38+
{"offset", 2, 2, 4, 8, 5},
39+
{"shape_b", 3, 4, 8, 16, 0},
40+
{"shape_b_offset", 3, 4, 8, 16, 10},
41+
};
42+
43+
bool run_case(const std::string& dir, const UpdateCacheCase& tc) {
44+
printf(
45+
"\n--- Test: update_cache[%s] (S=%d,H=%d,D=%d,Cmax=%d,pos=%d) ---\n",
46+
tc.name,
47+
tc.s,
48+
tc.h,
49+
tc.d,
50+
tc.cmax,
51+
tc.input_pos);
52+
Module module(dir + "/" + tc.name + ".pte");
53+
if (module.load_forward() != Error::Ok) {
54+
printf("FAIL: could not load %s.pte\n", tc.name);
55+
return false;
56+
}
57+
58+
const int vnumel = tc.s * tc.h * tc.d;
59+
const int cnumel = tc.cmax * tc.h * tc.d;
60+
std::vector<float> value(vnumel);
61+
std::vector<float> cache(cnumel);
62+
for (int i = 0; i < vnumel; i++) {
63+
value[i] = static_cast<float>(i) * 0.5f;
64+
}
65+
for (int i = 0; i < cnumel; i++) {
66+
cache[i] = static_cast<float>(i) + 100.0f;
67+
}
68+
69+
// Inline reference: scatter value into the cache at input_pos, bounds-checked
70+
// exactly as the op (integer-exact copy, no library needed).
71+
std::vector<float> ref(cache);
72+
const int dst_offset = tc.input_pos * tc.h * tc.d;
73+
for (int i = 0; i < vnumel; i++) {
74+
if (dst_offset + i < cnumel) {
75+
ref[dst_offset + i] = value[i];
76+
}
77+
}
78+
79+
auto v = make_tensor_ptr({1, tc.s, tc.h, tc.d}, std::vector<float>(value));
80+
auto c = make_tensor_ptr({1, tc.cmax, tc.h, tc.d}, std::vector<float>(cache));
81+
auto result = module.forward({EValue(v), EValue(c)});
82+
if (!result.ok()) {
83+
printf("FAIL: forward failed (error %d)\n", (int)result.error());
84+
return false;
85+
}
86+
const auto& outputs = result.get();
87+
if (outputs.empty() || !outputs[0].isTensor()) {
88+
printf("FAIL: no tensor output\n");
89+
return false;
90+
}
91+
const auto& out_tensor = outputs[0].toTensor();
92+
if (static_cast<int>(out_tensor.numel()) != cnumel) {
93+
printf(
94+
"FAIL: output numel %zu != expected %d\n",
95+
(size_t)out_tensor.numel(),
96+
cnumel);
97+
return false;
98+
}
99+
const float* out_data = out_tensor.const_data_ptr<float>();
100+
101+
float max_abs_err = 0.0f;
102+
for (int i = 0; i < cnumel; i++) {
103+
max_abs_err = std::max(max_abs_err, std::abs(out_data[i] - ref[i]));
104+
}
105+
printf("Max abs error: %e (checked %d elements)\n", max_abs_err, cnumel);
106+
// update_cache is a pure scatter copy: the output must be bit-exact.
107+
if (max_abs_err > 0.0f) {
108+
printf("FAIL: update_cache[%s] not bit-exact\n", tc.name);
109+
return false;
110+
}
111+
printf("PASS: update_cache[%s]\n", tc.name);
112+
return true;
113+
}
114+
115+
// Parameters for one multi-step replay: run_replay executes one exported
// model per entry of seq_lens, feeding each step's returned cache into the
// next step.
struct ReplayCase {
116+
// Replay label; prefix of each step file
// "<name>_step<i>_S<s>_pos<input_pos>.pte".
const char* name;
117+
// Dim 2 of the value and cache tensors (printed as "H").
int h;
118+
// Dim 3 of the value and cache tensors (printed as "D").
int d;
119+
// Per-step dim 1 of the value tensor. Their sum is used as the cache's dim 1
// (Cmax), and the running prefix sum is each step's input_pos.
std::vector<int> seq_lens;
120+
};
121+
122+
// Multi-step advancing-input_pos cache accumulation, mirroring VulkanSDPATest.
123+
bool run_replay(const std::string& dir, const ReplayCase& rc) {
124+
int cmax = 0;
125+
for (int s : rc.seq_lens) {
126+
cmax += s;
127+
}
128+
printf(
129+
"\n--- Replay: update_cache[%s] (H=%d,D=%d,Cmax=%d,%zu steps) ---\n",
130+
rc.name,
131+
rc.h,
132+
rc.d,
133+
cmax,
134+
rc.seq_lens.size());
135+
136+
const int cnumel = cmax * rc.h * rc.d;
137+
std::vector<float> cache(cnumel);
138+
for (int i = 0; i < cnumel; i++) {
139+
cache[i] = static_cast<float>(i) + 100.0f;
140+
}
141+
std::vector<float> ref(cache);
142+
143+
int input_pos = 0;
144+
bool ok = true;
145+
for (size_t step = 0; step < rc.seq_lens.size(); step++) {
146+
const int s = rc.seq_lens[step];
147+
const int vnumel = s * rc.h * rc.d;
148+
std::vector<float> value(vnumel);
149+
const float base = static_cast<float>((input_pos + 1) * 1000);
150+
for (int i = 0; i < vnumel; i++) {
151+
value[i] = (base + static_cast<float>(i)) * 0.25f;
152+
}
153+
154+
const std::string fname = dir + "/" + rc.name + "_step" +
155+
std::to_string(step) + "_S" + std::to_string(s) + "_pos" +
156+
std::to_string(input_pos) + ".pte";
157+
Module module(fname);
158+
if (module.load_forward() != Error::Ok) {
159+
printf("FAIL: could not load %s\n", fname.c_str());
160+
return false;
161+
}
162+
163+
auto v = make_tensor_ptr({1, s, rc.h, rc.d}, std::vector<float>(value));
164+
auto c = make_tensor_ptr({1, cmax, rc.h, rc.d}, std::vector<float>(cache));
165+
auto result = module.forward({EValue(v), EValue(c)});
166+
if (!result.ok()) {
167+
printf(
168+
"FAIL: forward failed step %zu (error %d)\n",
169+
step,
170+
(int)result.error());
171+
return false;
172+
}
173+
const auto& outputs = result.get();
174+
if (outputs.empty() || !outputs[0].isTensor() ||
175+
static_cast<int>(outputs[0].toTensor().numel()) != cnumel) {
176+
printf("FAIL: bad cache output at step %zu\n", step);
177+
return false;
178+
}
179+
const float* out_data = outputs[0].toTensor().const_data_ptr<float>();
180+
181+
const int dst_offset = input_pos * rc.h * rc.d;
182+
for (int i = 0; i < vnumel; i++) {
183+
if (dst_offset + i < cnumel) {
184+
ref[dst_offset + i] = value[i];
185+
}
186+
}
187+
188+
float max_abs_err = 0.0f;
189+
for (int i = 0; i < cnumel; i++) {
190+
max_abs_err = std::max(max_abs_err, std::abs(out_data[i] - ref[i]));
191+
cache[i] = out_data[i]; // thread the accumulated cache into the next step
192+
}
193+
printf(
194+
" step %zu (S=%d,pos=%d): max abs error %e\n",
195+
step,
196+
s,
197+
input_pos,
198+
max_abs_err);
199+
if (max_abs_err > 0.0f) { // pure scatter copy: must be bit-exact
200+
ok = false;
201+
}
202+
input_pos += s;
203+
}
204+
205+
if (ok) {
206+
printf("PASS: update_cache[%s] replay\n", rc.name);
207+
} else {
208+
printf("FAIL: update_cache[%s] replay\n", rc.name);
209+
}
210+
return ok;
211+
}
212+
213+
// One expected-rejection scenario for run_negative_case.
struct NegativeCase {
214+
// Case label; also the stem of the "<name>.pte" file loaded from the case dir.
const char* name;
215+
// Human-readable description of the guard expected to reject the model; only
// used in log messages.
const char* guard;
216+
};
217+
218+
// Single-op, single-guard-violation cases: rejection maps to the named guard.
219+
bool run_negative_case(const std::string& dir, const NegativeCase& nc) {
220+
printf(
221+
"\n--- Negative: update_cache[%s] (expect rejection: %s) ---\n",
222+
nc.name,
223+
nc.guard);
224+
Module module(dir + "/" + nc.name + ".pte");
225+
const Error err = module.load_forward();
226+
// init catches the guard throw -> this code; other errors = setup failure.
227+
if (err != Error::DelegateInvalidCompatibility) {
228+
printf(
229+
"FAIL: %s.pte -> error %d; expected DelegateInvalidCompatibility "
230+
"from the '%s' guard\n",
231+
nc.name,
232+
(int)err,
233+
nc.guard);
234+
return false;
235+
}
236+
printf("PASS: rejected with DelegateInvalidCompatibility (%s)\n", nc.guard);
237+
return true;
238+
}
239+
240+
} // namespace
241+
242+
int main(int argc, char** argv) {
243+
std::string dir = "/tmp/update_cache";
244+
if (argc > 1) {
245+
dir = argv[1];
246+
}
247+
if (const char* env = std::getenv("WEBGPU_UPDATE_CACHE_DIR")) {
248+
dir = env;
249+
}
250+
251+
WebGPUContext ctx;
252+
try {
253+
ctx = create_webgpu_context();
254+
} catch (const std::exception& e) {
255+
printf("SKIP: %s\n", e.what());
256+
return 0;
257+
}
258+
set_default_webgpu_context(&ctx);
259+
printf("WebGPU device acquired (native); case dir: %s\n", dir.c_str());
260+
261+
bool ok = true;
262+
for (const auto& tc : kCases) {
263+
ok = run_case(dir, tc) && ok;
264+
}
265+
266+
const std::vector<ReplayCase> kReplays = {
267+
{"seqA", 4, 4, {3, 1, 1, 5, 1, 1, 2}},
268+
{"seqB", 2, 8, {3, 1, 1, 5, 1, 1}},
269+
{"llama3", 8, 128, {111, 1, 1, 1, 57, 1, 1}},
270+
};
271+
for (const auto& rc : kReplays) {
272+
ok = run_replay(dir, rc) && ok;
273+
}
274+
275+
const NegativeCase kNegatives[] = {
276+
{"neg_batch", "batch must be 1"},
277+
{"neg_fp16", "fp32-only"},
278+
};
279+
for (const auto& nc : kNegatives) {
280+
ok = run_negative_case(dir, nc) && ok;
281+
}
282+
283+
set_default_webgpu_context(nullptr);
284+
destroy_webgpu_context(ctx);
285+
286+
if (!ok) {
287+
return 1;
288+
}
289+
printf("\nAll update_cache tests passed\n");
290+
return 0;
291+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.

0 commit comments

Comments
 (0)