Skip to content
41 changes: 35 additions & 6 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def update_features_impl(op: OpKey):
torch.ops.aten.sym_size.int,
operator.add,
operator.sub,
operator.floordiv,
operator.mul,
operator.lt,
operator.gt,
operator.ge,
Expand Down Expand Up @@ -279,6 +281,26 @@ def register_bitwise_and():
)


@update_features(exir_ops.edge.aten.bitwise_not.default)
def register_bitwise_not():
return OpFeatures(
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.BOOL_T,
supports_resize=True,
supports_highdim=True,
)


@update_features(exir_ops.edge.aten.logical_and.default)
def register_logical_and():
return OpFeatures(
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.BOOL_T,
supports_resize=True,
supports_highdim=True,
)


# =============================================================================
# BinaryScalarOp.cpp
# =============================================================================
Expand All @@ -301,16 +323,22 @@ def register_pow_tensor_scalar():

@update_features(exir_ops.edge.aten._to_copy.default)
def register_to_copy():
def check_to_copy_node(node: torch.fx.Node) -> bool:
# Only single-arg _to_copy is supported
return len(node.args) == 1
def pick_to_copy_storage(
node: torch.fx.Node,
) -> Tuple[utils.TensorRepSet, utils.TensorRepSet]:
in_dtype = node.args[0].meta["val"].dtype # type: ignore[union-attr]
out_dtype = node.meta["val"].dtype
fp_types = {torch.float16, torch.float32}
if in_dtype in fp_types and out_dtype in fp_types:
return utils.ANY_STORAGE, utils.ANY_STORAGE
return utils.CONTIGUOUS_BUFFER, utils.CONTIGUOUS_BUFFER

return OpFeatures(
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_INT_T,
outputs_dtypes=utils.FP_INT_T,
inputs_dtypes=utils.FP_INT_BOOL_T,
outputs_dtypes=utils.FP_INT_BOOL_T,
supports_resize=True,
are_node_inputs_supported_fn=check_to_copy_node,
pick_io_storage_fn=pick_to_copy_storage,
)


Expand Down Expand Up @@ -1336,6 +1364,7 @@ def register_scalar_tensor():
return OpFeatures(
inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
inputs_dtypes=utils.FP_INT_T,
supports_resize=True,
)


Expand Down
3 changes: 3 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,6 @@ unary_op:
OPERATOR: leaky_relu(X, A)
- NAME: round
OPERATOR: round(X)
- NAME: bitwise_not_uint8
OPERATOR: 1 - X
DTYPE: uint8
1 change: 1 addition & 0 deletions backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ REGISTER_OPERATORS {
VK_REGISTER_OP(aten.gt.Tensor, gt);
VK_REGISTER_OP(aten.ge.Tensor, ge);
VK_REGISTER_OP(aten.bitwise_and.Tensor, bitwise_and);
VK_REGISTER_OP(aten.logical_and.default, bitwise_and);
}

} // namespace vkcompute
42 changes: 2 additions & 40 deletions backends/vulkan/runtime/graph/ops/impl/Split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,52 +9,13 @@
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

#include <executorch/backends/vulkan/runtime/utils/StorageUtils.h>

namespace vkcompute {

using utils::GPUMemoryLayout;
using utils::StorageType;

void resize_split_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& resize_args) {
(void)resize_args;
const ValueRef input = args.at(0).refs.at(0);
const ValueRef split_sizes_ref = args.at(1).refs.at(0);
const ValueRef dim_ref = args.at(2).refs.at(0);
const ValueRef out_list_ref = args.at(3).refs.at(0);

const ValueListPtr out_list = graph->get_value_list(out_list_ref);
const std::vector<int64_t> split_sizes =
*(graph->get_int_list(split_sizes_ref));
const int64_t dim = graph->extract_scalar<int64_t>(dim_ref);

const int64_t input_ndim = graph->dim_of(input);
const DimIndex dim_index = dim < 0 ? static_cast<DimIndex>(dim)
: static_cast<DimIndex>(dim - input_ndim);

std::vector<int64_t> input_sizes = graph->sizes_of(input);

for (int split_idx = 0; split_idx < split_sizes.size(); split_idx++) {
const int64_t split_size = split_sizes.at(split_idx);
const ValueRef out_ref = out_list->at(split_idx);

std::vector<int64_t> out_sizes = input_sizes;
out_sizes.at(dim_index) = split_size;

graph->virtual_resize(out_ref, out_sizes);
}
}

void add_split_node(
ComputeGraph& graph,
const ValueRef input,
Expand Down Expand Up @@ -125,7 +86,8 @@ void split_with_sizes_copy_default(
ValueRef out_list_ref = args[3];

int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
std::vector<int64_t> split_sizes = *(graph.get_int_list(split_sizes_ref));
std::vector<int64_t> split_sizes =
graph.extract_int_or_symint_list(split_sizes_ref);

add_split_with_sizes_node(graph, input, split_sizes, dim, out_list_ref);
}
Expand Down
87 changes: 87 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,90 @@ void sym_add(ComputeGraph& graph, const std::vector<ValueRef>& args) {
new ExecuteNode(resize_sym_add_node, args));
}

void sym_sub_impl(ComputeGraph* graph, const std::vector<ValueRef>& args) {
const ValueRef a = args.at(0);
const ValueRef b = args.at(1);
const ValueRef out = args.at(2);

const int32_t a_val = graph->read_symint(a);
const int32_t b_val = graph->read_symint(b);
const int32_t result = a_val - b_val;

graph->set_symint(out, result);
}

void resize_sym_sub_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& resize_args) {
(void)args;
sym_sub_impl(graph, resize_args);
}

void sym_sub(ComputeGraph& graph, const std::vector<ValueRef>& args) {
sym_sub_impl(&graph, args);

graph.execute_nodes().emplace_back(
new ExecuteNode(resize_sym_sub_node, args));
}

void sym_floordiv_impl(ComputeGraph* graph, const std::vector<ValueRef>& args) {
const ValueRef a = args.at(0);
const ValueRef b = args.at(1);
const ValueRef out = args.at(2);

const int32_t a_val = graph->read_symint(a);
const int32_t b_val = graph->read_symint(b);
// Floor division: round towards negative infinity
const int32_t result = (a_val ^ b_val) < 0 && a_val % b_val != 0
? a_val / b_val - 1
: a_val / b_val;

graph->set_symint(out, result);
}

void resize_sym_floordiv_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& resize_args) {
(void)args;
sym_floordiv_impl(graph, resize_args);
}

void sym_floordiv(ComputeGraph& graph, const std::vector<ValueRef>& args) {
sym_floordiv_impl(&graph, args);

graph.execute_nodes().emplace_back(
new ExecuteNode(resize_sym_floordiv_node, args));
}

void sym_mul_impl(ComputeGraph* graph, const std::vector<ValueRef>& args) {
const ValueRef a = args.at(0);
const ValueRef b = args.at(1);
const ValueRef out = args.at(2);

const int32_t a_val = graph->read_symint(a);
const int32_t b_val = graph->read_symint(b);
const int32_t result = a_val * b_val;

graph->set_symint(out, result);
}

void resize_sym_mul_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& resize_args) {
(void)args;
sym_mul_impl(graph, resize_args);
}

void sym_mul(ComputeGraph& graph, const std::vector<ValueRef>& args) {
sym_mul_impl(&graph, args);

graph.execute_nodes().emplace_back(
new ExecuteNode(resize_sym_mul_node, args));
}

void select_as_symint_impl(
ComputeGraph* graph,
const std::vector<ArgGroup>& unused,
Expand Down Expand Up @@ -132,6 +216,9 @@ void select_as_symint(ComputeGraph& graph, const std::vector<ValueRef>& args) {
REGISTER_OPERATORS {
VK_REGISTER_OP(sym_size.int, sym_size_int);
VK_REGISTER_OP(add, sym_add);
VK_REGISTER_OP(sub, sym_sub);
VK_REGISTER_OP(floordiv, sym_floordiv);
VK_REGISTER_OP(mul, sym_mul);
VK_REGISTER_OP(et_vk.select_as_symint.default, select_as_symint);
}

Expand Down
4 changes: 2 additions & 2 deletions backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ bool is_float_type(vkapi::ScalarType dtype) {
}

void add_to_copy_node(ComputeGraph& graph, ValueRef in, ValueRef out) {
vkapi::ScalarType in_dtype = graph.dtype_of(in);
vkapi::ScalarType out_dtype = graph.dtype_of(out);
const vkapi::ScalarType in_dtype = graph.dtype_of(in);
const vkapi::ScalarType out_dtype = graph.dtype_of(out);

// Same-dtype or float<->half conversions can use BlitNode
if (in_dtype == out_dtype ||
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ DEFINE_ACTIVATION_FN(hardswish);
DEFINE_ACTIVATION_FN(hardsigmoid);
DEFINE_LEAKY_RELU_FN(leaky_relu);
DEFINE_ACTIVATION_FN(round);
DEFINE_ACTIVATION_FN(bitwise_not);

REGISTER_OPERATORS {
VK_REGISTER_OP(aten.abs.default, abs);
Expand All @@ -179,6 +180,7 @@ REGISTER_OPERATORS {
VK_REGISTER_OP(aten.hardsigmoid.default, hardsigmoid);
VK_REGISTER_OP(aten.leaky_relu.default, leaky_relu);
VK_REGISTER_OP(aten.round.default, round);
VK_REGISTER_OP(aten.bitwise_not.default, bitwise_not);
}

} // namespace vkcompute
18 changes: 15 additions & 3 deletions backends/vulkan/runtime/graph/ops/impl/Where.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,22 @@ void resize_where_node(
const std::vector<ValueRef>& extra_args) {
(void)extra_args;
const ValueRef out = args.at(0).refs.at(0);
const ValueRef self = args.at(1).refs.at(1);

const std::vector<int64_t> self_sizes = graph->sizes_of(self);
graph->virtual_resize(out, self_sizes);
std::vector<int64_t> out_sizes;
for (const ValueRef ref : args.at(1).refs) {
if (!graph->val_is_tensor(ref)) {
continue;
}
const std::vector<int64_t> s = graph->sizes_of(ref);
if (s.size() > out_sizes.size()) {
out_sizes.resize(s.size(), 1);
}
const size_t offset = out_sizes.size() - s.size();
for (size_t i = 0; i < s.size(); i++) {
out_sizes[offset + i] = std::max(out_sizes[offset + i], s[i]);
}
}
graph->virtual_resize(out, out_sizes);
}

void add_where_node(
Expand Down
Loading