Skip to content
18 changes: 12 additions & 6 deletions onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,19 @@ Status CastOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& model

bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
const logging::Logger& logger) const {
if (input_params.create_mlprogram) {
// The ML Program 'cast' op stands alone, so a Cast fed directly by a graph
// input (no preceding node) is fine here.
return true;
}

// The NeuralNetwork path only supports a Cast that consumes an ArgMax, so it
// needs a preceding node to inspect (InputEdgesBegin() must be dereferenceable).
if (node.GetInputEdgesCount() == 0) {
LOGS(logger, VERBOSE) << "Cast has no preceding nodes.";
return false;
}

if (input_params.create_mlprogram) {
return true;
}

const auto& prec_node = node.InputEdgesBegin()->GetNode();

/*Cast node is only aimed for supporting argmax and we are only handling the case where an argmax
Expand Down Expand Up @@ -141,11 +145,13 @@ bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, [[maybe_unused]] co
if ((input_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 ||
input_type == ONNX_NAMESPACE::TensorProto_DataType_INT64 ||
input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT ||
input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) &&
input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 ||
input_type == ONNX_NAMESPACE::TensorProto_DataType_BOOL) &&
(output_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 ||
output_type == ONNX_NAMESPACE::TensorProto_DataType_INT64 ||
output_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT ||
output_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)) {
output_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 ||
output_type == ONNX_NAMESPACE::TensorProto_DataType_BOOL)) {
return true;
} else {
LOGS(logger, VERBOSE) << "[" << node.OpType()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <optional>
#include <vector>

#include "core/providers/coreml/builders/impl/base_op_builder.h"
#include "core/providers/coreml/builders/impl/builder_utils.h"
#include "core/providers/coreml/builders/model_builder.h"
#include "core/providers/coreml/builders/op_builder_factory.h"
#include "core/providers/coreml/shape_utils.h"
#include "core/providers/shared/utils/utils.h"

namespace onnxruntime {
namespace coreml {

// Op builder for ONNX GatherND(data, indices), which is emitted as the CoreML
// ML Program 'gather_nd' operation. Every member is a private override of a
// BaseOpBuilder hook.
class GatherNDOpBuilder : public BaseOpBuilder {
 private:
  // This builder can emit ML Program operations.
  bool SupportsMLProgram() const override { return true; }

  // Node-level gate: model format and the 'batch_dims' attribute.
  bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                         const logging::Logger& logger) const override;

  // Type-level gate: element types of 'data' and 'indices'.
  bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params,
                              const logging::Logger& logger) const override;

  // Emits the CoreML operation(s) for a node that passed the gates above.
  Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                               const logging::Logger& logger) const override;
};

// Emits the ML Program operations for one GatherND node.
//
// Non-bool 'data':  gather_nd(x = data, indices) -> node output.
// Bool 'data':      cast(bool -> int32) -> gather_nd -> cast(int32 -> bool),
//                   because CoreML's gather_nd does not accept a bool 'x'
//                   (transformer attention-mask graphs gather from bool
//                   tensors). int32 represents 0/1 exactly, so the round-trip
//                   is lossless.
//
// Always returns Status::OK().
Status GatherNDOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                                const logging::Logger& logger) const {
  using namespace CoreML::Specification::MILSpec;

  const auto& inputs = node.InputDefs();
  const auto& outputs = node.OutputDefs();
  const auto& data_arg = *inputs[0];
  const auto& indices_arg = *inputs[1];
  const auto& result_arg = *outputs[0];

  // The element type stays UNDEFINED if GetType cannot resolve it, in which
  // case the plain (non-bool) path below is taken.
  int32_t data_elem_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
  GetType(data_arg, data_elem_type, logger);
  const bool round_trip_bool = (data_elem_type == ONNX_NAMESPACE::TensorProto_DataType_BOOL);

  // Wraps a shape as the optional span taken by AddIntermediateOperationOutput;
  // an unknown shape is passed as std::nullopt.
  const auto shape_or_none = [](bool known, const std::vector<int64_t>& dims) {
    std::optional<gsl::span<const int64_t>> maybe_dims;
    if (known) {
      maybe_dims.emplace(dims);
    }
    return maybe_dims;
  };

  // Name of the value fed to gather_nd as 'x': the ONNX input itself, or the
  // int32 copy produced by the leading cast on the bool path.
  std::string_view gather_data_name = data_arg.Name();

  if (round_trip_bool) {
    std::vector<int64_t> data_dims;
    const bool data_dims_known = GetShape(data_arg, data_dims, logger);
    const std::string& widened_name = model_builder.GetUniqueName(node, "gather_nd_x_int32");

    std::unique_ptr<Operation> widen = model_builder.CreateOperation(node, "cast");
    AddOperationInput(*widen, "x", data_arg.Name());
    const auto int32_dtype = model_builder.AddScalarConstant(widen->type(), "dtype", std::string("int32"));
    AddOperationInput(*widen, "dtype", int32_dtype);
    AddIntermediateOperationOutput(*widen, widened_name, ONNX_NAMESPACE::TensorProto_DataType_INT32,
                                   shape_or_none(data_dims_known, data_dims));
    model_builder.AddOperation(std::move(widen));

    gather_data_name = widened_name;
  }

  // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.scatter_gather.gather_nd
  // The iOS15 gather_nd takes no batch_dims parameter, so it matches ONNX
  // GatherND only for batch_dims == 0; IsOpSupportedImpl rejects other values.
  std::unique_ptr<Operation> gather = model_builder.CreateOperation(node, "gather_nd");
  AddOperationInput(*gather, "x", gather_data_name);
  AddOperationInput(*gather, "indices", indices_arg.Name());
  // validate_indices is documented as optional, but the ML Program parser
  // rejects gather_nd without it (same as the 'gather' op builder).
  const auto validate_flag = model_builder.AddScalarConstant(gather->type(), "validate_indices", false);
  AddOperationInput(*gather, "validate_indices", validate_flag);

  if (round_trip_bool) {
    // gather_nd produced int32 here; route it through an intermediate value
    // and cast back to bool so the result matches the ONNX output type.
    std::vector<int64_t> result_dims;
    const bool result_dims_known = GetShape(result_arg, result_dims, logger);
    const std::string& gathered_name = model_builder.GetUniqueName(node, "gather_nd_out_int32");
    AddIntermediateOperationOutput(*gather, gathered_name, ONNX_NAMESPACE::TensorProto_DataType_INT32,
                                   shape_or_none(result_dims_known, result_dims));
    model_builder.AddOperation(std::move(gather));

    std::unique_ptr<Operation> narrow = model_builder.CreateOperation(node, "cast");
    AddOperationInput(*narrow, "x", gathered_name);
    const auto bool_dtype = model_builder.AddScalarConstant(narrow->type(), "dtype", std::string("bool"));
    AddOperationInput(*narrow, "dtype", bool_dtype);
    AddOperationOutput(*narrow, result_arg);
    model_builder.AddOperation(std::move(narrow));
  } else {
    // Plain path: gather_nd writes the node output directly.
    AddOperationOutput(*gather, result_arg);
    model_builder.AddOperation(std::move(gather));
  }

  return Status::OK();
}

// Node-level support check for GatherND.
//
// Returns true only when an ML Program is being created and the node's
// 'batch_dims' attribute is 0 (a missing attribute is read as 0). Each
// rejection is logged at VERBOSE level.
bool GatherNDOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                                          const logging::Logger& logger) const {
  if (input_params.create_mlprogram) {
    // The iOS15 gather_nd op has no batch_dims parameter, so only the ONNX
    // default of batch_dims == 0 maps onto it directly.
    NodeAttrHelper attrs(node);
    const int64_t batch_dims = attrs.Get("batch_dims", int64_t{0});
    if (batch_dims == 0) {
      return true;
    }

    LOGS(logger, VERBOSE) << "GatherND only supports batch_dims == 0. Got: " << batch_dims;
    return false;
  }

  LOGS(logger, VERBOSE) << "GatherND is only supported for the ML Program format.";
  return false;
}

// Element-type check for the two GatherND inputs.
//
// Returns false if either element type cannot be read via GetType, if 'data'
// is not one of float / float16 / int32 / int64 / bool, or if 'indices' is
// not int32 / int64. Type rejections are logged at VERBOSE level.
bool GatherNDOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/,
                                               const logging::Logger& logger) const {
  const auto& inputs = node.InputDefs();

  int32_t data_type = 0;
  if (!GetType(*inputs[0], data_type, logger)) {
    return false;
  }

  int32_t indices_type = 0;
  if (!GetType(*inputs[1], indices_type, logger)) {
    return false;
  }

  // gather_nd is type-agnostic over 'x' except that it rejects bool; bool
  // 'data' (transformer mask graphs) is still accepted here because
  // AddToModelBuilderImpl wraps the op in a cast round-trip for that case.
  switch (data_type) {
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
    case ONNX_NAMESPACE::TensorProto_DataType_INT32:
    case ONNX_NAMESPACE::TensorProto_DataType_INT64:
    case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
      break;
    default: {
      LOGS(logger, VERBOSE) << "GatherND: 'data' input type not supported. Got type: " << data_type;
      return false;
    }
  }

  // ONNX GatherND indices are int64; the CoreML EP converts int64 <-> int32.
  const bool indices_ok = indices_type == ONNX_NAMESPACE::TensorProto_DataType_INT64 ||
                          indices_type == ONNX_NAMESPACE::TensorProto_DataType_INT32;
  if (!indices_ok) {
    LOGS(logger, VERBOSE) << "GatherND: 'indices' input must be int32 or int64. Got type: " << indices_type;
    return false;
  }

  return true;
}

// Registers a GatherNDOpBuilder for `op_type`: the new builder is appended to
// `op_registrations.builders`, and a pointer to it is stored in
// `op_registrations.op_builder_map` under `op_type`.
void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  auto& owned_builders = op_registrations.builders;
  owned_builders.push_back(std::make_unique<GatherNDOpBuilder>());

  auto* const registered = owned_builders.back().get();
  op_registrations.op_builder_map.emplace(op_type, registered);
}

} // namespace coreml
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateDepthToSpaceOpBuilder("DepthToSpace", op_registrations);
CreateFlattenOpBuilder("Flatten", op_registrations);
CreateGatherOpBuilder("Gather", op_registrations);
CreateGatherNDOpBuilder("GatherND", op_registrations);
CreateGemmOpBuilder("Gemm", op_registrations);
CreateGridSampleOpBuilder("GridSample", op_registrations);
CreateIdentityOpBuilder("Identity", op_registrations);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ void CreateConvTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrat
void CreateDepthToSpaceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGridSampleOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateIdentityOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
Expand Down
Loading