Commit 21e9fe3

maxwbuckley and claude authored
[CoreML EP] Add GatherND builder (#28598)
### Summary

New ML Program op builder: ONNX `GatherND` → CoreML `gather_nd`.

- `batch_dims` must be 0: the iOS15 `gather_nd` op has no `batch_dims` parameter, so `IsOpSupportedImpl` rejects other values.
- CoreML's `gather_nd` rejects a **bool `x`**, but transformer attention-mask graphs gather from bool tensors. For bool data the builder lowers the op as `cast(bool→int32) → gather_nd → cast(int32→bool)`; int32 represents 0/1 exactly, so the round-trip is lossless.
- `validate_indices` is passed explicitly: the ML Program parser rejects `gather_nd` without it (the same quirk the `gather` builder works around).
- ML-Program-only; `IsOpSupportedImpl` rejects the NeuralNetwork format.

### Indices handling (CoreML `gather_nd` quirks)

Two CoreML behaviours that differ from ONNX are handled in the builder:

- **`indices` must be a constant initializer.** CoreML's `gather_nd` miscomputes the result for some data/indices shape combinations when `indices` is a runtime (non-constant) input: it returns slice 0 regardless of the actual index value. With a constant `indices` it is correct, so non-constant cases fall back to CPU. Constant indices is also the common case (e.g. transformer attention masks).
- **Negative indices are normalized at build time.** ONNX `GatherND` wraps a negative index by the corresponding data dim; CoreML's `gather_nd` does not and silently returns wrong values. Since `indices` is constant, the builder wraps any negatives into positive int32 indices while building the model (and requires the indexed data dims to be static, otherwise the node falls back to CPU).

This was surfaced by fuzzing over randomized shapes/indices and verified on-device (negative indices, scalar outputs, ranks 2–4) against the CPU reference.

### Depends on the bool-Cast PR

The bool-data `GatherND` test needs `Cast` as the `int ↔ bool` producer/consumer so the bool tensors stay internal to the CoreML partition (a partition cannot have bool I/O). This branch is **stacked on `coreml-cast-bool`**: the `cb43b7c75f` commit in this PR is the bool-Cast PR and drops from this diff once that one merges.

### Tests (`coreml_basic_test.cc`)

- `GatherND_MLProgram`: a float `GatherND` runs on CoreML, matches CPU.
- `GatherNDBoolData_MLProgram`: a `Cast → GatherND → Cast` bool chain runs fully on CoreML, exercising the cast round-trip lowering.
- `GatherNDNeuralNetworkNotSupported`: `GatherND` falls back on the NeuralNetwork format.
- `GatherNDBatchDimsNotSupported`: `GatherND` with `batch_dims=1` falls back to CPU.

Doc: `coreml_supported_mlprogram_ops.md` lists `GatherND`.

### Series: CoreML EP coverage for transformer / diffusion graphs

- #28595: Support bool Cast in ML Program *(prerequisite)*
- #28596: Add Sin and Cos unary ops *(independent)*
- #28597: Add Where and And builders *(depends on #28595)*
- **#28598: Add GatherND builder** *(this PR, depends on #28595)*

Together with #28278 (scalar-`Gather`), the series takes BERT / GPT-2 / ViT / diffusion-UNet graphs, tiny and full-size, from 2 CoreML partitions to 1, with zero graph breaks.

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
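
The negative-index normalization described under "Indices handling" can be sketched in isolation. This is a minimal illustration, not the builder's code: `NormalizeGatherNDIndices` is a hypothetical helper name, and it assumes `batch_dims == 0`, a flattened row-major `indices` buffer whose last dimension is `depth`, and fully static indexed data dims.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of the commit): wraps negative GatherND indices.
// `indices` is the flattened, row-major indices tensor whose last dimension is
// `depth`, so element i addresses data dimension (i % depth) when batch_dims == 0.
std::vector<int64_t> NormalizeGatherNDIndices(const std::vector<int64_t>& indices,
                                              std::size_t depth,
                                              const std::vector<int64_t>& data_shape) {
  assert(depth > 0 && depth <= data_shape.size());
  std::vector<int64_t> normalized;
  normalized.reserve(indices.size());
  for (std::size_t i = 0; i < indices.size(); ++i) {
    int64_t v = indices[i];
    if (v < 0) v += data_shape[i % depth];  // ONNX semantics: -1 is the last element
    normalized.push_back(v);
  }
  return normalized;
}
```

For a `{2, 3}` data tensor and index tuples `(-1, -1)` and `(0, -2)`, the sketch yields `(1, 2)` and `(0, 1)`; already non-negative values pass through unchanged.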
1 parent d113549 commit 21e9fe3

8 files changed

Lines changed: 614 additions & 2 deletions

Lines changed: 225 additions & 0 deletions
@@ -0,0 +1,225 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <optional>
+#include <vector>
+
+#include "core/optimizer/initializer.h"
+#include "core/providers/coreml/builders/impl/base_op_builder.h"
+#include "core/providers/coreml/builders/impl/builder_utils.h"
+#include "core/providers/coreml/builders/model_builder.h"
+#include "core/providers/coreml/builders/op_builder_factory.h"
+#include "core/providers/coreml/shape_utils.h"
+#include "core/providers/shared/utils/utils.h"
+
+namespace onnxruntime {
+namespace coreml {
+
+// ONNX GatherND(data, indices) maps to the CoreML ML Program 'gather_nd' op.
+class GatherNDOpBuilder : public BaseOpBuilder {
+  void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
+
+  Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
+                               const logging::Logger& logger) const override;
+
+  bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
+                         const logging::Logger& logger) const override;
+
+  bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params,
+                              const logging::Logger& logger) const override;
+
+  bool SupportsMLProgram() const override { return true; }
+};
+
+Status GatherNDOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
+                                                const logging::Logger& logger) const {
+  using namespace CoreML::Specification::MILSpec;
+  const auto& input_defs = node.InputDefs();
+  const auto& output_defs = node.OutputDefs();
+
+  // CoreML's gather_nd does not accept a bool 'x'. Transformer attention-mask
+  // graphs gather from bool tensors, so for that case the op is composed as
+  // cast(bool -> int32) -> gather_nd -> cast(int32 -> bool). int32 represents
+  // 0/1 exactly, so the round-trip is lossless.
+  int32_t data_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
+  GetType(*input_defs[0], data_type, logger);
+  const bool data_is_bool = data_type == ONNX_NAMESPACE::TensorProto_DataType_BOOL;
+
+  std::string_view gather_x_name = input_defs[0]->Name();
+  if (data_is_bool) {
+    std::vector<int64_t> x_shape;
+    const bool has_x_shape = GetShape(*input_defs[0], x_shape, logger);
+    const std::string& cast_x_name = model_builder.GetUniqueName(node, "gather_nd_x_int32");
+    std::unique_ptr<Operation> cast_in = model_builder.CreateOperation(node, "cast");
+    AddOperationInput(*cast_in, "x", input_defs[0]->Name());
+    AddOperationInput(*cast_in, "dtype",
+                      model_builder.AddScalarConstant(cast_in->type(), "dtype", std::string("int32")));
+    AddIntermediateOperationOutput(*cast_in, cast_x_name, ONNX_NAMESPACE::TensorProto_DataType_INT32,
+                                   has_x_shape ? std::optional<gsl::span<const int64_t>>(x_shape)
+                                               : std::nullopt);
+    model_builder.AddOperation(std::move(cast_in));
+    gather_x_name = cast_x_name;
+  }
+
+  // ONNX GatherND permits negative indices (wrapped by the corresponding data dim); CoreML's gather_nd
+  // does not. The indices are a constant and the indexed data dims are static (both gated in
+  // IsOpSupportedImpl), so wrap any negatives now and re-emit them as an int32 'indices' constant. The
+  // original initializer is skipped (see AddInitializersToSkip).
+  std::string indices_name;
+  {
+    std::vector<int64_t> data_shape, indices_shape;
+    ORT_RETURN_IF_NOT(GetShape(*input_defs[0], data_shape, logger) &&
+                          GetShape(*input_defs[1], indices_shape, logger) && !indices_shape.empty(),
+                      "GatherND: failed to get data/indices shape");
+    const size_t depth = static_cast<size_t>(indices_shape.back());
+    const Initializer unpacked(*model_builder.GetConstantInitializer(input_defs[1]->Name()));
+    int32_t indices_type = ONNX_NAMESPACE::TensorProto_DataType_INT64;
+    GetType(*input_defs[1], indices_type, logger);
+
+    std::vector<int64_t> normalized;
+    const auto wrap = [&](auto src) {
+      normalized.reserve(src.size());
+      for (size_t i = 0; i < src.size(); ++i) {
+        int64_t v = static_cast<int64_t>(src[i]);
+        if (v < 0) v += data_shape[i % depth];
+        normalized.push_back(v);
+      }
+    };
+    if (indices_type == ONNX_NAMESPACE::TensorProto_DataType_INT32) {
+      wrap(unpacked.DataAsSpan<int32_t>());
+    } else {
+      wrap(unpacked.DataAsSpan<int64_t>());
+    }
+    // AddConstant with int64 values emits an int32 'const' (CoreML uses int32 indices).
+    indices_name = model_builder.AddConstant(node.OpType(), "indices", normalized,
+                                             gsl::span<const int64_t>(indices_shape));
+  }
+
+  // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.scatter_gather.gather_nd
+  // The iOS15 gather_nd has no batch_dims parameter and is equivalent to ONNX
+  // GatherND with batch_dims == 0 (other values are gated in IsOpSupportedImpl).
+  std::unique_ptr<Operation> op = model_builder.CreateOperation(node, "gather_nd");
+  AddOperationInput(*op, "x", gather_x_name);
+  AddOperationInput(*op, "indices", indices_name);
+  // CoreML docs mark validate_indices as optional, but the ML Program parser
+  // rejects gather_nd without it (same as the 'gather' op builder).
+  AddOperationInput(*op, "validate_indices",
+                    model_builder.AddScalarConstant(op->type(), "validate_indices", false));
+
+  if (!data_is_bool) {
+    AddOperationOutput(*op, *output_defs[0]);
+    model_builder.AddOperation(std::move(op));
+    return Status::OK();
+  }
+
+  // Cast the int32 gather_nd result back to bool to match the ONNX output type.
+  std::vector<int64_t> out_shape;
+  const bool has_out_shape = GetShape(*output_defs[0], out_shape, logger);
+  const std::string& gather_out_name = model_builder.GetUniqueName(node, "gather_nd_out_int32");
+  AddIntermediateOperationOutput(*op, gather_out_name, ONNX_NAMESPACE::TensorProto_DataType_INT32,
+                                 has_out_shape ? std::optional<gsl::span<const int64_t>>(out_shape)
+                                               : std::nullopt);
+  model_builder.AddOperation(std::move(op));
+
+  std::unique_ptr<Operation> cast_out = model_builder.CreateOperation(node, "cast");
+  AddOperationInput(*cast_out, "x", gather_out_name);
+  AddOperationInput(*cast_out, "dtype",
+                    model_builder.AddScalarConstant(cast_out->type(), "dtype", std::string("bool")));
+  AddOperationOutput(*cast_out, *output_defs[0]);
+  model_builder.AddOperation(std::move(cast_out));
+  return Status::OK();
+}
+
+bool GatherNDOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
+                                          const logging::Logger& logger) const {
+  if (!input_params.create_mlprogram) {
+    LOGS(logger, VERBOSE) << "GatherND is only supported for the ML Program format.";
+    return false;
+  }
+
+  // The iOS15 gather_nd op has no batch_dims parameter, so only batch_dims == 0
+  // (the ONNX default) maps directly.
+  NodeAttrHelper helper(node);
+  const auto batch_dims = helper.Get("batch_dims", int64_t{0});
+  if (batch_dims != 0) {
+    LOGS(logger, VERBOSE) << "GatherND only supports batch_dims == 0. Got: " << batch_dims;
+    return false;
+  }
+
+  // CoreML's gather_nd miscomputes the result for some data/indices shape combinations when 'indices'
+  // is a non-constant (runtime) input -- it returns slice 0 regardless of the actual index value. With
+  // a constant 'indices' the op is correct (verified on-device), and constant indices is the common case
+  // (e.g. transformer attention-mask gathers). Require a constant 'indices' so we never silently emit
+  // wrong results; non-constant cases fall back to CPU.
+  if (!input_params.graph_viewer.IsConstantInitializer(node.InputDefs()[1]->Name(), /*check_outer_scope*/ true)) {
+    LOGS(logger, VERBOSE) << "GatherND: 'indices' must be a constant initializer for the CoreML EP.";
+    return false;
+  }
+
+  // Negative indices are normalized to positive at build time (AddToModelBuilderImpl), which needs the
+  // indexed data dims -- the first indices.shape[-1] dims -- to be statically known.
+  std::vector<int64_t> data_shape, indices_shape;
+  if (!GetShape(*node.InputDefs()[0], data_shape, logger) ||
+      !GetShape(*node.InputDefs()[1], indices_shape, logger) || indices_shape.empty()) {
+    LOGS(logger, VERBOSE) << "GatherND: data or indices shape is unknown.";
+    return false;
+  }
+  const size_t depth = static_cast<size_t>(indices_shape.back());
+  if (depth > data_shape.size()) {
+    LOGS(logger, VERBOSE) << "GatherND: index tuple depth " << depth << " exceeds data rank " << data_shape.size();
+    return false;
+  }
+  for (size_t k = 0; k < depth; ++k) {
+    if (data_shape[k] < 0) {
+      LOGS(logger, VERBOSE) << "GatherND: indexed data dims must be static.";
+      return false;
+    }
+  }
+
+  return true;
+}
+
+void GatherNDOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
+  // 'indices' is re-emitted as a normalized int32 constant in AddToModelBuilderImpl, so skip the original.
+  model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name());
+}
+
+bool GatherNDOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/,
+                                               const logging::Logger& logger) const {
+  const auto& input_defs = node.InputDefs();
+  int32_t data_type = 0, indices_type = 0;
+  if (!GetType(*input_defs[0], data_type, logger) || !GetType(*input_defs[1], indices_type, logger)) {
+    return false;
+  }
+
+  // gather_nd itself is type-agnostic over 'x' but rejects bool; bool 'data'
+  // (transformer mask graphs) is supported via a cast round-trip in
+  // AddToModelBuilderImpl. INT64 'data' is accepted because the CoreML EP
+  // implicitly narrows int64 to int32 at the model boundary (the int64->int32
+  // input conversion in model.mm and the matching INT32 feature/output handling
+  // in ModelBuilder::RegisterModelInputOutput), so CoreML never sees int64.
+  if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
+      data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 &&
+      data_type != ONNX_NAMESPACE::TensorProto_DataType_INT32 &&
+      data_type != ONNX_NAMESPACE::TensorProto_DataType_INT64 &&
+      data_type != ONNX_NAMESPACE::TensorProto_DataType_BOOL) {
+    LOGS(logger, VERBOSE) << "GatherND: 'data' input type not supported. Got type: " << data_type;
+    return false;
+  }
+
+  // ONNX GatherND indices are int64; the CoreML EP converts int64 <-> int32.
+  if (indices_type != ONNX_NAMESPACE::TensorProto_DataType_INT64 &&
+      indices_type != ONNX_NAMESPACE::TensorProto_DataType_INT32) {
+    LOGS(logger, VERBOSE) << "GatherND: 'indices' input must be int32 or int64. Got type: " << indices_type;
+    return false;
+  }
+  return true;
+}
+
+void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
+  op_registrations.builders.push_back(std::make_unique<GatherNDOpBuilder>());
+  op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
+}
+
+}  // namespace coreml
+}  // namespace onnxruntime
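
For readers less familiar with the operator, the `batch_dims == 0` semantics that the builder relies on can be written as a small CPU-style reference. This is a sketch under stated assumptions, not ONNX Runtime's CPU kernel: `GatherNDReference` is a hypothetical name, tensors are flattened row-major `float` buffers, and all shapes are static.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical reference for GatherND with batch_dims == 0 (illustration only).
// `data` and `indices` are row-major; the last dimension of indices is `depth`.
// Each index tuple selects a contiguous slice of prod(data_shape[depth:]) elements.
std::vector<float> GatherNDReference(const std::vector<float>& data,
                                     const std::vector<int64_t>& data_shape,
                                     const std::vector<int64_t>& indices,
                                     std::size_t depth) {
  assert(depth > 0 && depth <= data_shape.size());
  assert(indices.size() % depth == 0);

  // Number of contiguous elements copied per index tuple.
  std::size_t slice_size = 1;
  for (std::size_t d = depth; d < data_shape.size(); ++d) {
    slice_size *= static_cast<std::size_t>(data_shape[d]);
  }

  std::vector<float> out;
  out.reserve((indices.size() / depth) * slice_size);
  for (std::size_t t = 0; t < indices.size(); t += depth) {
    // Fold the tuple into a row-major offset over the first `depth` dims.
    std::size_t offset = 0;
    for (std::size_t k = 0; k < depth; ++k) {
      int64_t v = indices[t + k];
      if (v < 0) v += data_shape[k];  // ONNX wraps negative indices
      assert(v >= 0 && v < data_shape[k]);
      offset = offset * static_cast<std::size_t>(data_shape[k]) + static_cast<std::size_t>(v);
    }
    const auto begin = data.begin() + static_cast<std::ptrdiff_t>(offset * slice_size);
    out.insert(out.end(), begin, begin + static_cast<std::ptrdiff_t>(slice_size));
  }
  return out;
}
```

With `depth` equal to the data rank every tuple picks a single element; with a smaller `depth` each tuple picks a whole trailing slice, which is why only the first `depth` data dims need to be static for the normalization step.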

onnxruntime/core/providers/coreml/builders/model_builder.cc

Lines changed: 110 additions & 2 deletions
@@ -917,6 +917,12 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i
           AddInt64Output(name);
         }
         break;
+      case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
+        // ArrayFeatureType has no bool, so (like int64) the external feature is INT32. The int32<->bool
+        // cast at the ML Program boundary is wired up below / in RewriteBoolGraphIOBoundaries(), and the
+        // runtime int32<->bool data conversion is handled in model.mm.
+        multi_array->set_datatype(ArrayFeatureType::INT32);
+        break;
       default: {
         // TODO: support other type
         return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
@@ -932,22 +938,123 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i
     return Status::OK();
   }
 
+  const bool is_bool = data_type == ONNX_NAMESPACE::TensorProto_DataType_BOOL;
+
   if (create_ml_program_) {
     if (is_input) {
       // the model inputs need to be wired up as args to the 'main' function.
       auto tensor_value_type = CreateNamedTensorValueType(node_arg, /*convert_scalar*/ true);
 
-      // Handle conversion from int64 to int32
+      // Handle conversion from int64 to int32. A bool feature is exposed as int32 too, so the function
+      // arg is int32; the int32->bool cast is inserted immediately below so the op builders see bool.
       tensor_value_type.mutable_type()->mutable_tensortype()->set_datatype(
-          OnnxDataTypeToMILSpec(data_type));
+          OnnxDataTypeToMILSpec(is_bool ? ONNX_NAMESPACE::TensorProto_DataType_INT32 : data_type));
 
       tensor_value_type.set_name(name);
 
       mlprogram_main_fn_->mutable_inputs()->Add(std::move(tensor_value_type));
+
+      if (is_bool) {
+        // Emit the int32->bool cast now (ahead of any consumer in the block). Consumers still reference
+        // `name`; RewriteBoolGraphIOBoundaries() repoints them at the bool value once they've been added.
+        const std::string bool_name = GetUniqueName(name + "_to_bool");
+        AddBoundaryCastOp(name, bool_name, ONNX_NAMESPACE::TensorProto_DataType_BOOL, shape);
+        bool_input_value_rename_[name] = bool_name;
+      }
     } else {
       // the model outputs need to be set as outputs of the Block for the 'main' function
       *mlprogram_main_block_->mutable_outputs()->Add() = name;
+
+      if (is_bool) {
+        // The op builders produce a bool value named `name`; RewriteBoolGraphIOBoundaries() inserts a
+        // bool->int32 cast so the int32 feature/block-output `name` is satisfied.
+        bool_graph_outputs_.emplace_back(name, shape);
+      }
+    }
+  }
+
+  return Status::OK();
+}
+
+void ModelBuilder::AddBoundaryCastOp(std::string_view input_value_name, std::string_view output_value_name,
+                                     int32_t output_onnx_type, gsl::span<const int64_t> shape) {
+  auto op = std::make_unique<MILSpec::Operation>();
+  op->set_type("cast");
+  (*op->mutable_attributes())["name"] =
+      CreateScalarTensorValue(GetUniqueName(MakeString("boundary_cast_", output_value_name)));
+
+  AddOperationInput(*op, "x", input_value_name);
+  const std::string mil_dtype =
+      output_onnx_type == ONNX_NAMESPACE::TensorProto_DataType_BOOL ? "bool" : "int32";
+  AddOperationInput(*op, "dtype", AddScalarConstant(op->type(), "dtype", mil_dtype));
+  AddIntermediateOperationOutput(*op, output_value_name, output_onnx_type, shape);
+
+  AddOperation(std::move(op));
+}
+
+Status ModelBuilder::RewriteBoolGraphIOBoundaries() {
+  if (bool_input_value_rename_.empty() && bool_graph_outputs_.empty()) {
+    return Status::OK();
+  }
+
+  // bool graph inputs: the int32->bool cast was already emitted (ahead of consumers) in
+  // RegisterModelInputOutput. Repoint each consumer at the bool value. The cast ops themselves
+  // legitimately reference the original int32 input, so skip any op whose output is a rename target.
+  if (!bool_input_value_rename_.empty()) {
+    std::unordered_set<std::string> cast_outputs;
+    for (const auto& [orig, bool_name] : bool_input_value_rename_) {
+      cast_outputs.insert(bool_name);
+    }
+    for (auto& op : *mlprogram_main_block_->mutable_operations()) {
+      bool is_boundary_cast = false;
+      for (const auto& out : op.outputs()) {
+        if (Contains(cast_outputs, out.name())) {
+          is_boundary_cast = true;
+          break;
+        }
+      }
+      if (is_boundary_cast) {
+        continue;
+      }
+      for (auto& input : *op.mutable_inputs()) {
+        for (auto& arg : *input.second.mutable_arguments()) {
+          auto it = bool_input_value_rename_.find(arg.name());
+          if (it != bool_input_value_rename_.end()) {
+            arg.set_name(it->second);
+          }
+        }
+      }
+    }
+  }
+
+  // bool graph outputs: the op builders produced a bool value named `name`. Rename that producer's output
+  // (and any internal consumers) to a bool intermediate, then append a bool->int32 cast producing the
+  // int32 feature/block-output `name`.
+  for (const auto& [name, shape] : bool_graph_outputs_) {
+    const std::string pre_name = GetUniqueName(name + "_from_bool");
+    bool found = false;
+    for (auto& op : *mlprogram_main_block_->mutable_operations()) {
+      for (auto& out : *op.mutable_outputs()) {
+        if (out.name() == name) {
+          out.set_name(pre_name);
+          found = true;
+        }
+      }
+    }
+    if (!found) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+                             "RewriteBoolGraphIOBoundaries: bool graph output not produced by any operation: ", name);
+    }
+    for (auto& op : *mlprogram_main_block_->mutable_operations()) {
+      for (auto& input : *op.mutable_inputs()) {
+        for (auto& arg : *input.second.mutable_arguments()) {
+          if (arg.name() == name) {
+            arg.set_name(pre_name);
+          }
+        }
+      }
     }
+    AddBoundaryCastOp(pre_name, name, ONNX_NAMESPACE::TensorProto_DataType_INT32, shape);
   }
 
   return Status::OK();
@@ -994,6 +1101,7 @@ Status ModelBuilder::CreateModel() {
   ORT_RETURN_IF_ERROR(RegisterModelOutputs());
 
   if (create_ml_program_) {
+    ORT_RETURN_IF_ERROR(RewriteBoolGraphIOBoundaries());
     SanitizeNames();
   }
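
The input half of `RewriteBoolGraphIOBoundaries()` is a rename pass that must leave the boundary casts themselves untouched. A simplified sketch of that idea follows; `Op` and `RepointConsumers` are hypothetical stand-ins that use plain string lists instead of the MILSpec protobuf types, so this only illustrates the control flow and is not the EP's implementation.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for a MIL operation: just named inputs and outputs.
struct Op {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Repoint every consumer of a renamed value at its replacement. Ops that
// produce a replacement (the boundary casts) are skipped, because they must
// keep reading the original int32 value.
void RepointConsumers(std::vector<Op>& ops,
                      const std::unordered_map<std::string, std::string>& rename) {
  std::unordered_set<std::string> cast_outputs;
  for (const auto& kv : rename) {
    cast_outputs.insert(kv.second);
  }

  for (Op& op : ops) {
    bool is_boundary_cast = false;
    for (const auto& out : op.outputs) {
      if (cast_outputs.count(out) != 0) {
        is_boundary_cast = true;
        break;
      }
    }
    if (is_boundary_cast) {
      continue;
    }
    for (auto& in : op.inputs) {
      auto it = rename.find(in);
      if (it != rename.end()) {
        in = it->second;
      }
    }
  }
}
```

Skipping ops by their output name, rather than by op type, keeps the pass from rewriting the cast that defines the replacement value while still rewriting any ordinary `cast` that happens to consume the graph input.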
