Skip to content

Commit b7e567f

Browse files
author
ssjia
committed
[ET-VK][ops] Add eq.Scalar operator
Pull Request resolved: #20383 Adds Vulkan support for `aten.eq.Scalar`. This is the second of two ops needed to collapse the Llama4-mini TISO en_US backbone export to a single Vulkan partition (after `bitwise_or`): the discrete-speech mask compares the int token-id tensor against scalar constants via `aten.eq.Scalar`, which previously had no Vulkan implementation and forced a CPU fallback that split the delegated graph. Implemented by extending the existing tensor-scalar binary-op path with a comparison-output variant: `binary_scalar_buffer.glsl` / `binary_scalar_texture.glsl` gain an `IS_COMPARISON_OP` code path that writes a `uint8` (bool) output while leaving the existing arithmetic (e.g. `pow`) path unchanged; `binary_scalar_buffer.yaml` / `binary_scalar_texture.yaml` add an `eq_scalar` variant (half/float/int32 — the texture variant uses `equal(X, Y)` for per-lane `bvec4`, the buffer variant uses scalar `X == Y`); `BinaryScalarOp.cpp` adds an `eq_tensor_scalar` dispatch and `VK_REGISTER_OP(aten.eq.Scalar, eq_tensor_scalar)`; `op_registry.py` registers `aten.eq.Scalar` `OpFeatures` (FP/INT tensor input, bool output). The int64 token tensor is serialized to int32 via the existing `downcast_64_bit` path, so the dispatch resolves to the int32 shader variant; no dtype-conversion pass is added. This change was authored with Claude. ghstack-source-id: 396618180 @exported-using-ghexport Differential Revision: [D108457791](https://our.internmc.facebook.com/intern/diff/D108457791/)
1 parent 04eda6b commit b7e567f

7 files changed

Lines changed: 87 additions & 10 deletions

File tree

backends/vulkan/op_registry.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,17 @@ def register_pow_tensor_scalar():
327327
)
328328

329329

330+
# Feature declaration attached to the edge `aten.eq.Scalar` op via
# `update_features`. Returns an OpFeatures with inputs_storage=ANY_STORAGE,
# inputs_dtypes=FP_INT_T and outputs_dtypes=BOOL_T, and with both
# supports_resize and supports_highdim enabled.
@update_features(exir_ops.edge.aten.eq.Scalar)
331+
def register_eq_scalar():
332+
return OpFeatures(
333+
inputs_storage=utils.ANY_STORAGE,
334+
inputs_dtypes=utils.FP_INT_T,
335+
outputs_dtypes=utils.BOOL_T,
336+
supports_resize=True,
337+
supports_highdim=True,
338+
)
339+
340+
330341
# =============================================================================
331342
# ToCopy.cpp
332343
# =============================================================================

backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.glsl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,24 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
// Binary comparison ops write a boolean (uint8) output, so the output dtype
10+
// differs from the input dtype. IS_COMPARISON_OP is set explicitly per shader variant in the .yaml.
11+
912
#version 450 core
1013

1114
${define_required_extensions(STORAGE, DTYPE)}
15+
$if IS_COMPARISON_OP:
16+
${define_required_extensions(STORAGE, "uint8")}
1217

1318
#define PRECISION ${PRECISION}
1419

1520
#define NAME ${VARIANT_NAME}
1621

1722
#define T ${buffer_scalar_type(DTYPE)}
23+
$if IS_COMPARISON_OP:
24+
#define OUT_T ${buffer_scalar_type("uint8")}
25+
$else:
26+
#define OUT_T ${buffer_scalar_type(DTYPE)}
1827

1928
#define op(X, Y) ${OPERATOR}
2029

@@ -24,7 +33,11 @@ layout(std430) buffer;
2433

2534
#include "indexing.glslh"
2635

27-
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
36+
$if IS_COMPARISON_OP:
37+
${layout_declare_tensor(B, "w", "t_out", "uint8", STORAGE)}
38+
$else:
39+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
40+
2841
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
2942

3043
${layout_declare_ubo(B, "BufferMetadata", "outp")}
@@ -36,13 +49,14 @@ layout(push_constant) uniform restrict Block {
3649

3750
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3851

39-
#include "binary_op_defs.glslh"
52+
$if not IS_COMPARISON_OP:
53+
#include "binary_op_defs.glslh"
4054

4155
void main() {
4256
const uint out_bufi = gl_GlobalInvocationID.x;
4357
if (out_of_bounds(out_bufi, outp)) {
4458
return;
4559
}
4660

47-
t_out[out_bufi] = T(op(t_in[out_bufi], T(scalar_value)));
61+
t_out[out_bufi] = OUT_T(op(t_in[out_bufi], T(scalar_value)));
4862
}

backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
binary_scalar_buffer:
88
parameter_names_with_default_values:
99
OPERATOR: power_of(X, Y)
10-
NDIM: 3
10+
IS_COMPARISON_OP: false
1111
DTYPE: float
12-
PACKING: C_packed
1312
STORAGE: buffer
1413
generate_variant_forall:
1514
DTYPE:
@@ -18,3 +17,6 @@ binary_scalar_buffer:
1817
- VALUE: int32
1918
shader_variants:
2019
- NAME: pow_scalar_buffer
20+
- NAME: eq_scalar_buffer
21+
OPERATOR: X == Y
22+
IS_COMPARISON_OP: true

backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.glsl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,25 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
// Binary comparison ops write a boolean (uint8) output, so the output dtype
10+
// differs from the input dtype. IS_COMPARISON_OP is set explicitly per shader variant in the .yaml.
11+
912
#version 450 core
1013

1114
${define_required_extensions(STORAGE, DTYPE)}
15+
$if IS_COMPARISON_OP:
16+
${define_required_extensions(STORAGE, "uint8")}
1217

1318
#define PRECISION ${PRECISION}
1419

1520
#define NAME ${VARIANT_NAME}
1621

1722
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
1823
#define T ${texel_load_component_type(DTYPE, STORAGE)}
24+
$if IS_COMPARISON_OP:
25+
#define VEC4_OUT_T ${texel_load_type("uint8", STORAGE)}
26+
$else:
27+
#define VEC4_OUT_T VEC4_T
1928

2029
#define op(X, Y) ${OPERATOR}
2130

@@ -25,7 +34,11 @@ layout(std430) buffer;
2534

2635
#include "indexing.glslh"
2736

28-
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
37+
$if IS_COMPARISON_OP:
38+
${layout_declare_tensor(B, "w", "t_out", "uint8", STORAGE)}
39+
$else:
40+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
41+
2942
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
3043

3144
${layout_declare_ubo(B, "TextureMetadata", "outp")}
@@ -37,7 +50,8 @@ layout(push_constant) uniform restrict Block {
3750

3851
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3952

40-
#include "binary_op_defs.glslh"
53+
$if not IS_COMPARISON_OP:
54+
#include "binary_op_defs.glslh"
4155

4256
void main() {
4357
const ivec3 pos = ivec3(gl_GlobalInvocationID);
@@ -47,7 +61,7 @@ void main() {
4761
}
4862

4963
VEC4_T in_texel = texelFetch(t_in, pos, 0);
50-
VEC4_T out_texel = VEC4_T(op(in_texel, VEC4_T(scalar_value)));
64+
VEC4_OUT_T out_texel = VEC4_OUT_T(op(in_texel, VEC4_T(scalar_value)));
5165

5266
imageStore(t_out, pos, out_texel);
5367
}

backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
binary_scalar_texture:
88
parameter_names_with_default_values:
99
OPERATOR: power_of(X, Y)
10-
NDIM: 3
10+
IS_COMPARISON_OP: false
1111
DTYPE: float
12-
PACKING: C_packed
1312
STORAGE: texture3d
1413
generate_variant_forall:
1514
DTYPE:
@@ -18,3 +17,6 @@ binary_scalar_texture:
1817
- VALUE: int32
1918
shader_variants:
2019
- NAME: pow_scalar_texture3d
20+
- NAME: eq_scalar_texture3d
21+
OPERATOR: equal(X, Y)
22+
IS_COMPARISON_OP: true

backends/vulkan/runtime/graph/ops/impl/BinaryScalarOp.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,13 @@ void pow_tensor_scalar(ComputeGraph& graph, const std::vector<ValueRef>& args) {
7373
return add_binary_scalar_op_node(graph, args[0], args[1], args[2], "pow");
7474
}
7575

76+
// Operator entry point registered for aten.eq.Scalar. Forwards args[0],
// args[1] and args[2] unchanged to add_binary_scalar_op_node with the op name
// "eq", mirroring pow_tensor_scalar, which passes "pow".
// NOTE(review): presumably args are (input tensor, scalar, output) -- confirm
// against add_binary_scalar_op_node's signature.
void eq_tensor_scalar(ComputeGraph& graph, const std::vector<ValueRef>& args) {
77+
return add_binary_scalar_op_node(graph, args[0], args[1], args[2], "eq");
78+
}
79+
7680
REGISTER_OPERATORS {
7781
VK_REGISTER_OP(aten.pow.Tensor_Scalar, pow_tensor_scalar);
82+
VK_REGISTER_OP(aten.eq.Scalar, eq_tensor_scalar);
7883
}
7984

8085
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2237,3 +2237,32 @@ def get_pow_tensor_scalar_inputs():
22372237
]
22382238
test_suite.dtypes = ["at::kFloat"]
22392239
return test_suite
2240+
2241+
2242+
# Builds the VkTestSuite registered under "aten.eq.Scalar". Each case is a
# pair whose second element is the scalar compared against; the first element
# is presumably the input tensor's sizes (TODO confirm against VkTestSuite).
# The suite runs over buffer and 3D-texture storage and over width-packed and
# channels-packed layouts, with inputs generated by "make_seq_tensor".
@register_test_suite("aten.eq.Scalar")
2243+
def get_eq_scalar_inputs():
2244+
# Scalars are chosen to fall within the make_seq_tensor range (1..numel),
2245+
# so each case exercises a genuine mix of equal / not-equal elements rather
2246+
# than a trivially all-false comparison.
2247+
test_suite = VkTestSuite(
2248+
[
2249+
((M1,), 5),
2250+
((M2, M1), 100),
2251+
((S1, M1, M2), 1000),
2252+
((S1, S2, S2, M2), 2000),
2253+
((S, S1, S2), 50),
2254+
((M1, M2), 700),
2255+
((S1, S2), 20),
2256+
]
2257+
)
2258+
test_suite.storage_types = [
2259+
"utils::kBuffer",
2260+
"utils::kTexture3D",
2261+
]
2262+
test_suite.layouts = [
2263+
"utils::kWidthPacked",
2264+
"utils::kChannelsPacked",
2265+
]
2266+
# NOTE(review): only at::kInt inputs are exercised by this suite;
# floating-point inputs are not covered here.
test_suite.dtypes = ["at::kInt"]
2267+
test_suite.data_gen = "make_seq_tensor"
2268+
return test_suite

0 commit comments

Comments
 (0)