Skip to content

Commit 0cce3c6

Browse files
author
ssjia
committed
Update on "[ET-VK] Add symint infrastructure to VulkanBackend and ComputeGraph"
Extend the Vulkan backend runtime infrastructure to better support symbolic integer (symint) arguments. This is a prerequisite for operators that need to handle dynamic shapes via symint values. Changes: - VulkanBackend.cpp: Compute output offset from end of args instead of assuming outputs follow inputs directly. Add scalar-to-tensor input handling so that Int/Bool EValues can populate tensor inputs. Support symint inputs provided as raw Int EValues (not just scalar tensors). Add symint output handling to write values back as tensor or Int EValue. - ComputeGraph.h: Add SymInt case to extract_scalar<T>() so operators can transparently read symint values as scalars. - ComputeGraph.cpp: Add Int fallback in read_symint() so values stored as plain Int (rather than SymInt objects) can be read uniformly. Differential Revision: [D95970167](https://our.internmc.facebook.com/intern/diff/D95970167/) cc manuelcandales digantdesai cbilgin [ghstack-poisoned]
2 parents f019b92 + 3278cf7 commit 0cce3c6

11 files changed

Lines changed: 292 additions & 72 deletions

File tree

backends/arm/test/models/test_w2l_arm.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@
2424

2525
input_t = Tuple[torch.Tensor] # Input x
2626

27+
quant_test_data = {
28+
"per_channel_quantization=true": True,
29+
"per_channel_quantization=false": False,
30+
}
31+
2732

2833
def get_test_inputs(batch_size, num_features, input_frames):
2934
return (torch.randn(batch_size, num_features, input_frames),)
@@ -99,14 +104,15 @@ def test_w2l_u55_INT():
99104

100105
@pytest.mark.slow
101106
@common.XfailIfNoCorstone320
102-
@pytest.mark.skip(reason="Intermittent timeout issue: MLETORCH-856")
103-
def test_w2l_u85_INT():
107+
@common.parametrize("per_channel_quantization", quant_test_data)
108+
def test_w2l_u85_INT(per_channel_quantization):
104109
pipeline = EthosU85PipelineINT[input_t](
105-
TestW2L.create_model(),
110+
TestW2L.create_model("power_spectrum"),
106111
TestW2L.model_example_inputs,
107112
aten_ops=[],
108113
exir_ops=[],
109114
use_to_edge_transform_and_lower=True,
115+
per_channel_quantization=per_channel_quantization,
110116
)
111117
pipeline.run()
112118

backends/arm/test/ops/test_ceil.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def test_ceil_tosa_INT(test_data: input_t1):
8787
module.exir_op,
8888
atol=0.06,
8989
rtol=0.01,
90+
frobenius_threshold=0.2,
9091
)
9192
pipeline.run()
9293

backends/arm/test/pytest.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
addopts = --strict-markers
33
markers =
44
slow: Tests that take long time
5+
flaky: Tests that are known to be flaky/intermittent
56
tosa_ref_model: Tests that use TOSA reference model # Temporary!

backends/vulkan/runtime/vk_api/memory/Allocator.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
namespace vkcompute {
1212
namespace vkapi {
1313

14-
bool test_host_cached_available(VkPhysicalDevice physical_device) {
14+
VmaAllocationCreateFlags test_host_cached_available(
15+
VkPhysicalDevice physical_device) {
1516
VkPhysicalDeviceMemoryProperties mem_props;
1617
vkGetPhysicalDeviceMemoryProperties(physical_device, &mem_props);
1718

0 commit comments

Comments
 (0)