Skip to content

Commit d603bf6

Browse files
committed
temp changes to enable benchmarking on device via adb
Differential Revision: [D99677681](https://our.internmc.facebook.com/intern/diff/D99677681/) [ghstack-poisoned]
1 parent e05dcc6 commit d603bf6

3 files changed

Lines changed: 928 additions & 1 deletion

File tree

extension/llm/custom_ops/BUCK

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,12 @@ fbcode_target(_kind = runtime.python_binary,
134134
],
135135
)
136136

137+
fbcode_target(_kind = runtime.python_binary,
138+
name = "run_bench_on_device",
139+
srcs = ["run_bench_on_device.py"],
140+
main_function = "executorch.extension.llm.custom_ops.run_bench_on_device.main",
141+
)
142+
137143
fbcode_target(_kind = cpp_benchmark,
138144
name = "bench_sdpa",
139145
srcs = ["bench_sdpa.cpp"],

extension/llm/custom_ops/bench_sdpa.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,13 +296,14 @@ bool validate_config(
296296
fill_random(k, gen);
297297
fill_random(v, gen);
298298

299-
// Reference: ET custom_sdpa_out (10-param signature, standard layout)
299+
// Reference: ET custom_sdpa_out (standard [B,S,H,D] layout)
300300
Tensor out_ref = tf.zeros(
301301
{(int32_t)batch, (int32_t)q_seq_len, (int32_t)Hq, (int32_t)D});
302302
KernelRuntimeContext ctx{};
303303
torch::executor::native::custom_sdpa_out(
304304
ctx, q, k, v, start_pos,
305305
std::nullopt, 0.0, true, std::nullopt,
306+
false, false, false,
306307
out_ref);
307308

308309
// Test: GEMM-based standard SDPA
@@ -473,6 +474,9 @@ BENCHMARK_DEFINE_F(SDPABenchFixture, CustomSDPA)
473474
0.0, // dropout_p
474475
true, // is_causal
475476
std::nullopt, // scale
477+
false, // is_seq_dim_2
478+
false, // is_k_seq_dim_2
479+
false, // is_v_seq_dim_2
476480
*output_);
477481
}
478482
}

0 commit comments

Comments
 (0)