Skip to content

Commit 24da2f6

Browse files
ssjiaSS-JIA
authored andcommitted
[ez][ET-VK][partitioner] Allow layout-agnostic ops to accept quantized layouts
Pull Request resolved: #19395 Two changes that together let the partitioner keep PACKED_INT8 layouts flowing through identity-like ops, eliminating spurious clone dispatches: 1. utils.py: ANY_STORAGE_INCL_PACKED_INT8 (renamed from ALL_STORAGES_REPSET) previously claimed every layout (including PACKED_INT8_*) on the texture side, but PACKED_INT8 is buffer-only by convention — the texture indexing helpers and required_image_extents don't know about quantized layouts. Narrow the texture side to all_memory_layouts (float-only). Every existing call site is either an intersection identity or a wildcard for non-tensor / not-yet-prepacked args, so this narrow is non-breaking; and now the repset can act as a true universal set when intersected against quant-aware repsets. The new name slots cleanly next to ANY_STORAGE / ANY_BUFFER / ANY_TEXTURE and tells the reader exactly what is added: "like ANY_STORAGE, but also admits PACKED_INT8 (on the buffer side)". 2. op_registry.py: switch view_copy / clone / _clone_dim_order / alias_copy from inputs_storage=ANY_STORAGE to inputs_storage=ANY_STORAGE_INCL_PACKED_INT8. ANY_STORAGE is float-only, so when one of these no-op identity ops sits between two q8ta ops the BFS in TagMemoryMetaPass.constrain_op_*_repset short-circuits (zero overlap with PACKED_INT8_BUFFER) and forces transitions on both sides. With ANY_STORAGE_INCL_PACKED_INT8 they now admit both float and quantized layouts and the redundant-op transform folds them away. The 31 other ops using ANY_STORAGE are real compute ops (binaryop, comparison, softmax, argreduce, permute_copy, etc.) whose float-only kernels do not accept quantized int8x4 layouts (q8ta_* are separate ops); leaving those alone. On RefineNet 24feat (1x3x256x144) the 8 _clone_dim_order ops the partitioner had been inserting around the 4 fused q8ta_pixel_shuffle nodes are now folded by the delegate. Runtime q8ta_clone dispatches drop from 11 to 3 (the 3 residuals are unrelated, from the original model graph). ghstack-source-id: 379519734 @exported-using-ghexport Differential Revision: [D103770022](https://our.internmc.facebook.com/intern/diff/D103770022/)
1 parent 0cafcb2 commit 24da2f6

4 files changed

Lines changed: 24 additions & 17 deletions

File tree

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,10 +252,10 @@ def get_arg_tensor_source_repset(
252252
"""
253253
arg_node = op_node.args[arg_i]
254254

255-
# For non-tensor arguments, return ALL_STORAGES_REPSET so that the respset does
255+
# For non-tensor arguments, return ANY_STORAGE_INCL_PACKED_INT8 so that the respset does
256256
# not appear to be empty.
257257
if not utils.is_tensor_arg_node(arg_node):
258-
return utils.ALL_STORAGES_REPSET
258+
return utils.ANY_STORAGE_INCL_PACKED_INT8
259259

260260
# Special case for cat - use the first tensor in the list as representative
261261
if isinstance(arg_node, list):

backends/vulkan/op_registry.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,7 +1158,7 @@ def register_permute_copy():
11581158
@update_features(exir_ops.edge.aten.view_copy.default)
11591159
def register_view_copy():
11601160
return OpFeatures(
1161-
inputs_storage=utils.ANY_STORAGE,
1161+
inputs_storage=utils.ANY_STORAGE_INCL_PACKED_INT8,
11621162
inputs_dtypes=utils.FP_INT_BOOL_T,
11631163
supports_resize=True,
11641164
supports_highdim=True,
@@ -1213,7 +1213,7 @@ def register_unsqueeze_copy():
12131213
@update_features(exir_ops.edge.aten.clone.default)
12141214
def register_clone():
12151215
return OpFeatures(
1216-
inputs_storage=utils.ANY_STORAGE,
1216+
inputs_storage=utils.ANY_STORAGE_INCL_PACKED_INT8,
12171217
inputs_dtypes=utils.FP_INT_BOOL_T,
12181218
supports_resize=True,
12191219
supports_highdim=True,
@@ -1223,7 +1223,7 @@ def register_clone():
12231223
@update_features(exir_ops.edge.dim_order_ops._clone_dim_order.default)
12241224
def register_clone_dim_order():
12251225
return OpFeatures(
1226-
inputs_storage=utils.ANY_STORAGE,
1226+
inputs_storage=utils.ANY_STORAGE_INCL_PACKED_INT8,
12271227
inputs_dtypes=utils.FP_INT_BOOL_T,
12281228
supports_resize=True,
12291229
supports_highdim=True,
@@ -1237,7 +1237,7 @@ def register_clone_dim_order():
12371237
@update_features(exir_ops.edge.aten.alias_copy.default)
12381238
def register_alias_copy():
12391239
return OpFeatures(
1240-
inputs_storage=utils.ANY_STORAGE,
1240+
inputs_storage=utils.ANY_STORAGE_INCL_PACKED_INT8,
12411241
inputs_dtypes=utils.FP_INT_BOOL_T,
12421242
supports_resize=True,
12431243
supports_highdim=True,

backends/vulkan/test/test_vulkan_tensor_repr.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -649,7 +649,7 @@ def test_no_sync_primary_io_when_different_repsets(self):
649649
# -- Scalar args are skipped --
650650

651651
def test_scalar_arg_skipped(self):
652-
"""Non-tensor args should be treated as ALL_STORAGES_REPSET."""
652+
"""Non-tensor args should be treated as ANY_STORAGE_INCL_PACKED_INT8."""
653653
tensor_arg = _make_tensor_arg_node((1, 3, 8, 8))
654654
# Second arg is a scalar (float)
655655
scalar_arg = 1.0
@@ -666,8 +666,8 @@ def test_scalar_arg_skipped(self):
666666
DEFAULT_TEXTURE_LIMITS,
667667
)
668668
self.assertFalse(op_repsets.any_is_empty())
669-
# The scalar arg should get ALL_STORAGES_REPSET
670-
# self.assertEqual(op_repsets.get_arg_repset(1), ALL_STORAGES_REPSET, f"""{op_repsets.get_arg_repset(1)}""")
669+
# The scalar arg should get ANY_STORAGE_INCL_PACKED_INT8
670+
# self.assertEqual(op_repsets.get_arg_repset(1), ANY_STORAGE_INCL_PACKED_INT8, f"""{op_repsets.get_arg_repset(1)}""")
671671

672672
# -- pick_representations --
673673

backends/vulkan/utils.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,8 +1203,15 @@ def filter_invalid_reprs_for_node_list(
12031203
# Special use RepSets
12041204

12051205
NO_STORAGE = TensorRepSet(set(), set())
1206-
ALL_STORAGES_REPSET = TensorRepSet(
1207-
universal_memory_layout_set, universal_memory_layout_set
1206+
# Buffer side admits both float and quantized (PACKED_INT8_*) layouts; texture side
1207+
# is float-only because the Vulkan backend has no quantized texture support
1208+
# (required_image_extents and the texture indexing helpers only know about the
1209+
# float layouts). Used as an intersection identity (e.g. common_arg_repset
1210+
# accumulator) and as a placeholder for non-tensor / not-yet-prepacked args, so
1211+
# narrowing the texture side is non-breaking for those uses while letting it act
1212+
# as a true universal set when intersected against quant-aware repsets.
1213+
ANY_STORAGE_INCL_PACKED_INT8 = TensorRepSet(
1214+
universal_memory_layout_set, all_memory_layouts
12081215
)
12091216

12101217

@@ -1330,19 +1337,19 @@ def __init__( # noqa: C901
13301337
# Now, go through the arguments of the operator and create a filtered repset
13311338
# for each based on the actual tensor value.
13321339
args_repset_list = TensorRepSetList([])
1333-
common_arg_repset = ALL_STORAGES_REPSET
1340+
common_arg_repset = ANY_STORAGE_INCL_PACKED_INT8
13341341
for i, arg_node in enumerate(op_node.args):
13351342
arg_repset = inputs_repsets[i]
13361343

1337-
# Use ALL_STORAGES_REPSET for non-tensor nodes so they don't cause the op
1344+
# Use ANY_STORAGE_INCL_PACKED_INT8 for non-tensor nodes so they don't cause the op
13381345
# repsets to appear empty
13391346
if not is_tensor_arg_node(arg_node):
1340-
args_repset_list.append(ALL_STORAGES_REPSET)
1347+
args_repset_list.append(ANY_STORAGE_INCL_PACKED_INT8)
13411348
# NO_STORAGE is used to denote that an input is either a non tensor arg or
13421349
# a weight tensor that is not prepacked. Similar to the above, use
1343-
# ALL_STORAGES_REPSET in this case.
1350+
# ANY_STORAGE_INCL_PACKED_INT8 in this case.
13441351
elif arg_repset.is_empty():
1345-
args_repset_list.append(ALL_STORAGES_REPSET)
1352+
args_repset_list.append(ANY_STORAGE_INCL_PACKED_INT8)
13461353
else:
13471354
assert not arg_repset.is_empty()
13481355

@@ -1355,7 +1362,7 @@ def __init__( # noqa: C901
13551362

13561363
# Repeat for output tensors.
13571364
outs_repset_list = TensorRepSetList([])
1358-
common_out_repset = ALL_STORAGES_REPSET
1365+
common_out_repset = ANY_STORAGE_INCL_PACKED_INT8
13591366
if num_tensors_in_node(op_node) == 1:
13601367
common_out_repset = filter_invalid_reprs(
13611368
op_node.meta["val"], outputs_repsets[0], texture_limits

0 commit comments

Comments
 (0)