Skip to content

Commit d802a2c

Browse files
author
ssjia
committed
Update on "[ET-VK] Add symint infrastructure to VulkanBackend and ComputeGraph"
Extend the Vulkan backend runtime infrastructure to better support symbolic integer (symint) arguments. This is a prerequisite for operators that need to handle dynamic shapes via symint values. Changes: - VulkanBackend.cpp: Compute output offset from end of args instead of assuming outputs follow inputs directly. Add scalar-to-tensor input handling so that Int/Bool EValues can populate tensor inputs. Support symint inputs provided as raw Int EValues (not just scalar tensors). Add symint output handling to write values back as tensor or Int EValue. - ComputeGraph.h: Add SymInt case to extract_scalar<T>() so operators can transparently read symint values as scalars. - ComputeGraph.cpp: Add Int fallback in read_symint() so values stored as plain Int (rather than SymInt objects) can be read uniformly. Differential Revision: [D95970167](https://our.internmc.facebook.com/intern/diff/D95970167/) cc manuelcandales digantdesai cbilgin [ghstack-poisoned]
2 parents 1575ecf + 2a12bf0 commit d802a2c

2 files changed

Lines changed: 9 additions & 10 deletions

File tree

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def constrain_repset_with_user(
302302
current_node, arg_repset, search_depth
303303
)
304304

305-
def trace_node_users_to_constrain_repset(
305+
def trace_node_users_to_constrain_repset( # noqa: C901
306306
self,
307307
origin_node: torch.fx.Node,
308308
repset: utils.TensorRepSet,

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -452,15 +452,14 @@ ValueRef ComputeGraph::add_tensor(
452452
const utils::AxisMapLayout axis_map_layout) {
453453
ValueRef idx(static_cast<int>(values_.size()));
454454
check_no_active_value_ptrs();
455-
values_.emplace_back(
456-
api::vTensor(
457-
context(),
458-
sizes,
459-
dtype,
460-
storage_type,
461-
memory_layout,
462-
false,
463-
axis_map_layout));
455+
values_.emplace_back(api::vTensor(
456+
context(),
457+
sizes,
458+
dtype,
459+
storage_type,
460+
memory_layout,
461+
false,
462+
axis_map_layout));
464463

465464
if (shared_object_idx >= 0) {
466465
get_shared_object(shared_object_idx).add_user(this, idx);

0 commit comments

Comments
 (0)