Skip to content

Commit 3070b7a

Browse files
author
ssjia
committed
[ET-VK] Add symint infrastructure to VulkanBackend and ComputeGraph
Extend the Vulkan backend runtime infrastructure to better support symbolic integer (symint) arguments. This is a prerequisite for operators that need to handle dynamic shapes via symint values. Changes: - VulkanBackend.cpp: Compute output offset from end of args instead of assuming outputs follow inputs directly. Add scalar-to-tensor input handling so that Int/Bool EValues can populate tensor inputs. Support symint inputs provided as raw Int EValues (not just scalar tensors). Add symint output handling to write values back as tensor or Int EValue. - ComputeGraph.h: Add SymInt case to extract_scalar<T>() so operators can transparently read symint values as scalars. - ComputeGraph.cpp: Add Int fallback in read_symint() so values stored as plain Int (rather than SymInt objects) can be read uniformly. Differential Revision: [D95970167](https://our.internmc.facebook.com/intern/diff/D95970167/) [ghstack-poisoned]
1 parent 615d52f commit 3070b7a

3 files changed

Lines changed: 77 additions & 27 deletions

File tree

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 62 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
671671
ComputeGraph* compute_graph = static_cast<ComputeGraph*>(handle);
672672

673673
const size_t num_inputs = compute_graph->inputs().size();
674+
const size_t num_outputs = compute_graph->outputs().size();
674675
bool should_propagate_resize = false;
675676
#ifdef ET_EVENT_TRACER_ENABLED
676677
runtime::EventTracer* event_tracer = context.event_tracer();
@@ -690,22 +691,51 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
690691
for (size_t i = 0; i < num_inputs; i++) {
691692
const ValueRef iref = compute_graph->inputs()[i].value;
692693
if (compute_graph->val_is_tensor(iref)) {
693-
VK_CHECK_COND(args[i]->isTensor());
694-
bool was_resized =
695-
maybe_resize_input(compute_graph, i, args[i]->toTensor());
696-
should_propagate_resize = should_propagate_resize || was_resized;
697-
compute_graph->maybe_cast_and_copy_into_staging(
698-
compute_graph->inputs()[i].staging,
699-
args[i]->toTensor().const_data_ptr(),
700-
args[i]->toTensor().numel(),
701-
equivalent_scalar_type(args[i]->toTensor().scalar_type()));
694+
if (args[i]->isTensor()) {
695+
bool was_resized =
696+
maybe_resize_input(compute_graph, i, args[i]->toTensor());
697+
should_propagate_resize = should_propagate_resize || was_resized;
698+
compute_graph->maybe_cast_and_copy_into_staging(
699+
compute_graph->inputs()[i].staging,
700+
args[i]->toTensor().const_data_ptr(),
701+
args[i]->toTensor().numel(),
702+
equivalent_scalar_type(args[i]->toTensor().scalar_type()));
703+
} else if (args[i]->isInt() || args[i]->isBool()) {
704+
int64_t val =
705+
args[i]->isInt() ? args[i]->toInt() : (args[i]->toBool() ? 1 : 0);
706+
vkapi::ScalarType tensor_dtype = compute_graph->dtype_of(iref);
707+
if (tensor_dtype == vkapi::kFloat) {
708+
float fval = static_cast<float>(val);
709+
compute_graph->maybe_cast_and_copy_into_staging(
710+
compute_graph->inputs()[i].staging, &fval, 1, vkapi::kFloat);
711+
} else if (tensor_dtype == vkapi::kInt) {
712+
int32_t ival = static_cast<int32_t>(val);
713+
compute_graph->maybe_cast_and_copy_into_staging(
714+
compute_graph->inputs()[i].staging, &ival, 1, vkapi::kInt);
715+
} else {
716+
compute_graph->maybe_cast_and_copy_into_staging(
717+
compute_graph->inputs()[i].staging, &val, 1, vkapi::kLong);
718+
}
719+
} else {
720+
VK_THROW(
721+
"Tensor input[",
722+
i,
723+
"] has unsupported EValue tag ",
724+
static_cast<int>(args[i]->tag));
725+
}
702726
} else if (compute_graph->val_is_symint(iref)) {
703-
VK_CHECK_COND(
704-
args[i]->isTensor(),
705-
"Cannot handle symint arg to graph that is not derived from a "
706-
"scalar tensor at the moment.");
707-
bool was_updated = maybe_update_scalar_tensor(
708-
compute_graph, iref, args[i]->toTensor());
727+
bool was_updated = false;
728+
if (args[i]->isTensor()) {
729+
was_updated = maybe_update_scalar_tensor(
730+
compute_graph, iref, args[i]->toTensor());
731+
} else if (args[i]->isInt()) {
732+
const int32_t new_val = static_cast<int32_t>(args[i]->toInt());
733+
const int32_t cur_val = compute_graph->read_symint(iref);
734+
if (new_val != cur_val) {
735+
compute_graph->set_symint(iref, new_val);
736+
was_updated = true;
737+
}
738+
}
709739
// Since symint inputs may impact tensor's sizes, trigger a resize if
710740
// any symbolic integer shapes are updated.
711741
should_propagate_resize = should_propagate_resize || was_updated;
@@ -770,14 +800,13 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
770800
"ETVK_COPY_OUTPUTS",
771801
/* delegate_debug_id = */ -1);
772802
#endif // ET_EVENT_TRACER_ENABLED
773-
for (size_t i = 0; i < compute_graph->outputs().size(); i++) {
774-
const size_t o = i + num_inputs;
803+
const size_t output_offset = args.size() - num_outputs;
804+
for (size_t i = 0; i < num_outputs; i++) {
805+
const size_t o = output_offset + i;
775806
const ValueRef oref = compute_graph->outputs()[i].value;
776807
if (compute_graph->val_is_tensor(oref)) {
777808
VK_CHECK_COND(args[o]->isTensor());
778809
maybe_resize_output(compute_graph, i, args[o]->toTensor());
779-
// args holds inputs directly followed by outputs, so the i'th output
780-
// for compute_graph corresponds to the o'th arg
781810
compute_graph->maybe_cast_and_copy_from_staging(
782811
compute_graph->outputs()[i].staging,
783812
args[o]->toTensor().mutable_data_ptr(),
@@ -789,6 +818,20 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
789818
// returned as an output, no action is required.
790819
else if (compute_graph->val_is_tref(oref)) {
791820
continue;
821+
} else if (compute_graph->val_is_symint(oref)) {
822+
const int32_t symint_val = compute_graph->read_symint(oref);
823+
if (args[o]->isTensor()) {
824+
executorch::aten::Tensor& out_tensor = args[o]->toTensor();
825+
executorch::aten::ScalarType dtype = out_tensor.scalar_type();
826+
if (dtype == executorch::aten::ScalarType::Int) {
827+
*out_tensor.mutable_data_ptr<int32_t>() = symint_val;
828+
} else if (dtype == executorch::aten::ScalarType::Long) {
829+
*out_tensor.mutable_data_ptr<int64_t>() =
830+
static_cast<int64_t>(symint_val);
831+
}
832+
} else if (args[o]->isInt()) {
833+
*args[o] = EValue(static_cast<int64_t>(symint_val));
834+
}
792835
} else {
793836
VK_THROW(
794837
"Could not handle output with type ",

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -452,14 +452,15 @@ ValueRef ComputeGraph::add_tensor(
452452
const utils::AxisMapLayout axis_map_layout) {
453453
ValueRef idx(static_cast<int>(values_.size()));
454454
check_no_active_value_ptrs();
455-
values_.emplace_back(api::vTensor(
456-
context(),
457-
sizes,
458-
dtype,
459-
storage_type,
460-
memory_layout,
461-
false,
462-
axis_map_layout));
455+
values_.emplace_back(
456+
api::vTensor(
457+
context(),
458+
sizes,
459+
dtype,
460+
storage_type,
461+
memory_layout,
462+
false,
463+
axis_map_layout));
463464

464465
if (shared_object_idx >= 0) {
465466
get_shared_object(shared_object_idx).add_user(this, idx);
@@ -725,6 +726,9 @@ void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
725726
}
726727

727728
int32_t ComputeGraph::read_symint(const ValueRef idx) {
729+
if (values_.at(idx).isInt()) {
730+
return static_cast<int32_t>(values_.at(idx).toInt());
731+
}
728732
return get_symint(idx)->get();
729733
}
730734

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,9 @@ class ComputeGraph final {
573573
if (value.isBool()) {
574574
return static_cast<T>(value.toBool());
575575
}
576+
if (value.isSymInt()) {
577+
return utils::safe_downcast<T>(read_symint(idx));
578+
}
576579
VK_THROW("Cannot extract scalar from Value with type ", value.type());
577580
}
578581

0 commit comments

Comments
 (0)