@@ -671,6 +671,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
671671 ComputeGraph* compute_graph = static_cast <ComputeGraph*>(handle);
672672
673673 const size_t num_inputs = compute_graph->inputs ().size ();
674+ const size_t num_outputs = compute_graph->outputs ().size ();
674675 bool should_propagate_resize = false ;
675676#ifdef ET_EVENT_TRACER_ENABLED
676677 runtime::EventTracer* event_tracer = context.event_tracer ();
@@ -690,22 +691,51 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
690691 for (size_t i = 0 ; i < num_inputs; i++) {
691692 const ValueRef iref = compute_graph->inputs ()[i].value ;
692693 if (compute_graph->val_is_tensor (iref)) {
693- VK_CHECK_COND (args[i]->isTensor ());
694- bool was_resized =
695- maybe_resize_input (compute_graph, i, args[i]->toTensor ());
696- should_propagate_resize = should_propagate_resize || was_resized;
697- compute_graph->maybe_cast_and_copy_into_staging (
698- compute_graph->inputs ()[i].staging ,
699- args[i]->toTensor ().const_data_ptr (),
700- args[i]->toTensor ().numel (),
701- equivalent_scalar_type (args[i]->toTensor ().scalar_type ()));
694+ if (args[i]->isTensor ()) {
695+ bool was_resized =
696+ maybe_resize_input (compute_graph, i, args[i]->toTensor ());
697+ should_propagate_resize = should_propagate_resize || was_resized;
698+ compute_graph->maybe_cast_and_copy_into_staging (
699+ compute_graph->inputs ()[i].staging ,
700+ args[i]->toTensor ().const_data_ptr (),
701+ args[i]->toTensor ().numel (),
702+ equivalent_scalar_type (args[i]->toTensor ().scalar_type ()));
703+ } else if (args[i]->isInt () || args[i]->isBool ()) {
704+ int64_t val =
705+ args[i]->isInt () ? args[i]->toInt () : (args[i]->toBool () ? 1 : 0 );
706+ vkapi::ScalarType tensor_dtype = compute_graph->dtype_of (iref);
707+ if (tensor_dtype == vkapi::kFloat ) {
708+ float fval = static_cast <float >(val);
709+ compute_graph->maybe_cast_and_copy_into_staging (
710+ compute_graph->inputs ()[i].staging , &fval, 1 , vkapi::kFloat );
711+ } else if (tensor_dtype == vkapi::kInt ) {
712+ int32_t ival = static_cast <int32_t >(val);
713+ compute_graph->maybe_cast_and_copy_into_staging (
714+ compute_graph->inputs ()[i].staging , &ival, 1 , vkapi::kInt );
715+ } else {
716+ compute_graph->maybe_cast_and_copy_into_staging (
717+ compute_graph->inputs ()[i].staging , &val, 1 , vkapi::kLong );
718+ }
719+ } else {
720+ VK_THROW (
721+ " Tensor input[" ,
722+ i,
723+ " ] has unsupported EValue tag " ,
724+ static_cast <int >(args[i]->tag ));
725+ }
702726 } else if (compute_graph->val_is_symint (iref)) {
703- VK_CHECK_COND (
704- args[i]->isTensor (),
705- " Cannot handle symint arg to graph that is not derived from a "
706- " scalar tensor at the moment." );
707- bool was_updated = maybe_update_scalar_tensor (
708- compute_graph, iref, args[i]->toTensor ());
727+ bool was_updated = false ;
728+ if (args[i]->isTensor ()) {
729+ was_updated = maybe_update_scalar_tensor (
730+ compute_graph, iref, args[i]->toTensor ());
731+ } else if (args[i]->isInt ()) {
732+ const int32_t new_val = static_cast <int32_t >(args[i]->toInt ());
733+ const int32_t cur_val = compute_graph->read_symint (iref);
734+ if (new_val != cur_val) {
735+ compute_graph->set_symint (iref, new_val);
736+ was_updated = true ;
737+ }
738+ }
709739 // Since symint inputs may impact tensor's sizes, trigger a resize if
710740 // any symbolic integer shapes are updated.
711741 should_propagate_resize = should_propagate_resize || was_updated;
@@ -770,14 +800,13 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
770800 " ETVK_COPY_OUTPUTS" ,
771801 /* delegate_debug_id = */ -1 );
772802#endif // ET_EVENT_TRACER_ENABLED
773- for (size_t i = 0 ; i < compute_graph->outputs ().size (); i++) {
774- const size_t o = i + num_inputs;
803+ const size_t output_offset = args.size () - num_outputs;
804+ for (size_t i = 0 ; i < num_outputs; i++) {
805+ const size_t o = output_offset + i;
775806 const ValueRef oref = compute_graph->outputs ()[i].value ;
776807 if (compute_graph->val_is_tensor (oref)) {
777808 VK_CHECK_COND (args[o]->isTensor ());
778809 maybe_resize_output (compute_graph, i, args[o]->toTensor ());
779- // args holds inputs directly followed by outputs, so the i'th output
780- // for compute_graph corresponds to the o'th arg
781810 compute_graph->maybe_cast_and_copy_from_staging (
782811 compute_graph->outputs ()[i].staging ,
783812 args[o]->toTensor ().mutable_data_ptr (),
@@ -789,6 +818,20 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
789818 // returned as an output, no action is required.
790819 else if (compute_graph->val_is_tref (oref)) {
791820 continue ;
821+ } else if (compute_graph->val_is_symint (oref)) {
822+ const int32_t symint_val = compute_graph->read_symint (oref);
823+ if (args[o]->isTensor ()) {
824+ executorch::aten::Tensor& out_tensor = args[o]->toTensor ();
825+ executorch::aten::ScalarType dtype = out_tensor.scalar_type ();
826+ if (dtype == executorch::aten::ScalarType::Int) {
827+ *out_tensor.mutable_data_ptr <int32_t >() = symint_val;
828+ } else if (dtype == executorch::aten::ScalarType::Long) {
829+ *out_tensor.mutable_data_ptr <int64_t >() =
830+ static_cast <int64_t >(symint_val);
831+ }
832+ } else if (args[o]->isInt ()) {
833+ *args[o] = EValue (static_cast <int64_t >(symint_val));
834+ }
792835 } else {
793836 VK_THROW (
794837 " Could not handle output with type " ,
0 commit comments